当前位置: 首页 > news >正文

使用pytorch创建/训练/推理OCR模型

一、任务描述

        从手写数字图像中自动识别出对应的数字(0-9)” 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)

        1、任务的核心定义:输入与输出

  • 输入:28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 “3” 的形状,就是模型的输入。
  • 输出:一个 “类别标签”,即从 10 个可能的类别(0、1、2、…、9)中选择一个,作为输入图像对应的数字,例如:输入 “3” 的图像,模型输出 “类别 3”,即完成一次正确识别。
  • 目标:让模型在 “未见的手写数字图像” 上,尽可能准确地输出正确类别(通常用 “准确率” 衡量,即正确识别的图像数 / 总图像数)

        2、任务的核心挑战

  • 不同人书写习惯差异极大:有人写的 “4” 带弯钩,有人写的 “7” 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 “5”,可能是 “直笔 5”“圆笔 5”,也可能是倾斜 10° 或 20° 的 “5”—— 模型需要忽略这些 “风格差异”,抓住 “数字的本质特征”(如 “5 有一个上半圆 + 一个竖线”)。
  • 图像噪声与干扰:手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 “0” 的图像,边缘有一小块污渍,模型需要判断 “这是噪声” 而不是 “0 的一部分”,避免误判为 “6” 或 “8”。

二、模型训练

       1、MNIST数据集

        MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 “基准数据集”,MNIST手写数字识别的核心是 “让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字”,它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。

  • 数据量适中:包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
  • 图像规格统一:所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
  • 标注准确:每张图像都有明确的 “正确数字标签”(人工标注),无需额外标注成本。

        2、代码

  • 数据准备:使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;
  • 模型定义:定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);
  • 训练设置:使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;
  • 训练过程:循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 1. 数据准备
# 定义数据变换
transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化,MNIST数据集的均值和标准差
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',  # 数据保存路径train=True,  # 训练集download=True,  # 如果数据不存在则下载transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,  # 测试集download=True,transform=transform
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 输入层到隐藏层self.fc1 = nn.Linear(28 * 28, 128)  # MNIST图像大小为28x28# 隐藏层到输出层self.fc2 = nn.Linear(128, 10)  # 10个类别(0-9)def forward(self, x):# 将图像展平为一维向量x = x.view(-1, 28 * 28)# 隐藏层,使用ReLU激活函数x = torch.relu(self.fc1(x))# 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # 交叉熵损失,适用于分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 4. 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 设置为训练模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(data)loss = criterion(outputs, target)# 反向传播和优化loss.backward()optimizer.step()running_loss += loss.item()# 每100个批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 6. 运行训练和测试
if __name__ == '__main__':# 训练模型print("开始训练模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)print("模型训练完成...")# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存为 mnist_model.pth")

三、模型使用测试

import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms  # 修正transforms的导入方式# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 加载模型
def load_model(model_path='mnist_model.pth'):model = SimpleNN()# 加载模型时添加参数以避免潜在的Python 3兼容性问题model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))model.eval()  # 设置为评估模式return model# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):# 打开图像并转换为灰度图img = Image.open(image_path).convert('L')  # 'L'表示灰度模式# 调整大小为28x28img = img.resize((28, 28))# 转换为numpy数组并归一化img_array = np.array(img) / 255.0# 定义图像转换(使用torchvision的transforms)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 注意:这里需要先将numpy数组转换为PIL图像再应用transformimg_pil = Image.fromarray((img_array * 255).astype(np.uint8))img_tensor = transform(img_pil).unsqueeze(0)  # 增加批次维度return img_tensor# 预测函数
def predict_digit(model, image_path):# 预处理图像img_tensor = preprocess_image(image_path)# 预测with torch.no_grad():  # 不计算梯度outputs = model(img_tensor)_, predicted = torch.max(outputs.data, 1)return predicted.item()  # 返回预测的数字# 示例使用
if __name__ == '__main__':# 加载模型model = load_model('mnist_model.pth')# 预测示例图像test_image_path = 'test_digit.png'  # 用户需要提供的测试图像路径try:predicted_digit = predict_digit(model, test_image_path)print(f"预测的数字是: {predicted_digit}")except Exception as e:print(f"预测出错: {str(e)}")

使用gpu0(第一块gpu)进行训练/推理:
        torch.cuda.set_device(0)    
        model = model.cuda(0)
使用cpu记性训练/推理:
        model = model.cpu()


怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制

YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析


文章转载自:

http://GMMCudRS.fkdts.cn
http://0nnuRN0M.fkdts.cn
http://HMgDwqDc.fkdts.cn
http://Y51uorL2.fkdts.cn
http://lqTGsuOK.fkdts.cn
http://EBmYrpla.fkdts.cn
http://s9oMDjMo.fkdts.cn
http://po6fojQ5.fkdts.cn
http://QN2R5H7e.fkdts.cn
http://C6GWFZeI.fkdts.cn
http://ilMHsUzi.fkdts.cn
http://xoAFVwc0.fkdts.cn
http://Web357rs.fkdts.cn
http://TJX0Hf3v.fkdts.cn
http://c8UXzgBz.fkdts.cn
http://cF21YJk0.fkdts.cn
http://BsKvd3mx.fkdts.cn
http://PegrvqhT.fkdts.cn
http://VguZnUaj.fkdts.cn
http://BYTELPrV.fkdts.cn
http://QndAVOr4.fkdts.cn
http://KSJYIMDG.fkdts.cn
http://FCkmRTIx.fkdts.cn
http://yvzyTmhD.fkdts.cn
http://eOk4GaU1.fkdts.cn
http://Bt8Llnqd.fkdts.cn
http://bZtvaeYY.fkdts.cn
http://QX5Q6wA1.fkdts.cn
http://j5RvGUXh.fkdts.cn
http://nRptRgzp.fkdts.cn
http://www.dtcms.com/a/369041.html

相关文章:

  • 从音频到文本实现高精度离线语音识别
  • 安防芯片ISP白平衡统计数据如何提升场景适应性?
  • Spring如何解决循环依赖:深入理解三级缓存机制
  • 当服务器出现网卡故障时如何检测网卡硬件故障并解决?
  • 【算法--链表】83.删除排序链表中的重复元素--通俗讲解
  • Grafana 导入仪表盘失败:从日志排查到解决 max\_allowed\_packet 问题
  • 像 Docker 一样创建虚拟网络
  • k8s除了主server服务器可正常使用kubectl命令,其他节点不能使用原因,以及如何在其他k8s节点正常使用kubectl命令??
  • xwiki sql注入漏洞复现(CVE-2025-32969)
  • MySQL】从零开始了解数据库开发 --- 表的操作
  • 「数据获取」《中国劳动统计年鉴》(1991-2024)
  • 手把手教你用Vue3+TypeScript+Vite搭建后台管理系统
  • oracle 使用CONNECT BY PRIOR 提示ORA-01436
  • 【数据分享】土地利用矢量shp数据分享-甘肃
  • PHP:驱动现代Web应用发展的核心力量
  • Vue项目API代理配置与断点调试
  • 永磁同步电机控制算法--传统IF控制结合滑模观测器的无感控制策略
  • 辗转相除法(欧几里得算法)的证明
  • 【MySQL索引设计实战:从入门到精通的高性能索引】
  • 《嵌入式硬件(三):串口通信》
  • python库 Py2exe 的详细使用(将 Python 脚本变为Windows独立软件包)
  • 激光雷达与IMU时间硬件同步与软件同步区分
  • 《基于stm32的智慧家居基础项目》
  • Docker在Windows与Linux系统安装的一体化教学设计
  • sub3G和sub6G的区别和联系
  • 【存储选型终极指南】RustFS vs MinIO:5大维度深度对决,95%技术团队的选择秘密!
  • 【Python基础】 18 Rust 与 Python print 函数完整对比笔记
  • Rust Axum 快速上手指南(静态网页和动态网页2024版)
  • CVPR 2025|无类别词汇的视觉-语言模型少样本学习
  • 9月14日 不见不散|2025年华望M-Design v2软件线上发布会