当前位置: 首页 > news >正文

PyTorch实现MNIST手写数字识别:从数据到模型全解析

本文将带你完整实现一个基于PyTorch的MNIST手写数字识别模型,包含数据加载、网络构建、训练优化和评估全流程。


1.数据加载与预处理

MNIST数据集包含6万张28×28像素的手写数字灰度图,我们使用PyTorch内置工具进行加载:

# 数据预处理:归一化到[-1,1]范围
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])# 加载数据集
train_dataset = mnist.MNIST('../data/', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('../data/', train=False, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

数据可视化展示样本分布:

examples = enumerate(test_loader)
_, (data, targets) = next(examples)plt.figure(figsize=(10,6))
for i in range(12):plt.subplot(3,4,i+1)plt.imshow(data[i][0], cmap='gray')plt.title(f"Label: {targets[i]}")plt.axis('off')
plt.tight_layout()

2. 神经网络模型设计

我们构建一个包含两个隐藏层的全连接网络,使用批归一化加速收敛:

class DigitRecognizer(nn.Module):def __init__(self, input_size, hidden1, hidden2, output_size):super().__init__()self.flatten = nn.Flatten()self.layer1 = nn.Sequential(nn.Linear(input_size, hidden1),nn.BatchNorm1d(hidden1))self.layer2 = nn.Sequential(nn.Linear(hidden1, hidden2),nn.BatchNorm1d(hidden2))self.out = nn.Linear(hidden2, output_size)def forward(self, x):x = self.flatten(x)x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))return F.softmax(self.out(x), dim=1)

网络结构说明

  • 输入层:784个神经元(28×28展平)
  • 隐藏层1:300个神经元 + 批归一化
  • 隐藏层2:100个神经元 + 批归一化
  • 输出层:10个神经元(对应0-9数字)
  • 激活函数:ReLU
  • 输出处理:Softmax归一化概率

3. 模型训练与优化

采用带动量的随机梯度下降(SGD)优化器,配合学习率衰减策略:

# 初始化模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DigitRecognizer(784, 300, 100, 10).to(device)# 损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 训练循环
for epoch in range(20):# 每5轮衰减学习率if epoch % 5 == 0:optimizer.param_groups[0]['lr'] *= 0.9print(f"Epoch {epoch}: LR={optimizer.param_groups[0]['lr']:.6f}")# 训练阶段model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 验证阶段model.eval()with torch.no_grad():# 计算验证集准确率correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)acc = 100 * correct / totalprint(f"Epoch {epoch}: Test Acc = {acc:.2f}%")

关键技术点

  1. 交叉熵损失函数:$$ \mathcal{L} = -\sum_{i=1}^{N} y_i \log(\hat{y}_i) $$
  2. 动量优化:$$ v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta) $$
  3. 学习率衰减:每5轮学习率乘以0.9
  4. 批归一化:加速训练并提高泛化能力

4. 训练结果分析

经过20轮训练,模型在测试集上达到98%+的准确率:

Epoch 0: Test Acc = 96.32%
Epoch 5: Test Acc = 97.86% (LR=0.008100)
Epoch 10: Test Acc = 98.12% (LR=0.007290)
Epoch 15: Test Acc = 98.24% (LR=0.006561)
Epoch 19: Test Acc = 98.37%

性能优化建议

  1. 尝试卷积神经网络(CNN)提升特征提取能力
  2. 增加数据增强(旋转、平移等)
  3. 使用更先进的优化器(Adam, RMSProp)
  4. 引入Dropout防止过拟合

5. 模型部署与应用

训练好的模型可保存并用于实时识别:

# 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')# 加载模型进行预测
loaded_model = DigitRecognizer(784, 300, 100, 10)
loaded_model.load_state_dict(torch.load('mnist_model.pth'))
loaded_model.eval()# 单样本预测
test_image = test_dataset[0][0].unsqueeze(0)
prediction = loaded_model(test_image)
print(f"预测数字: {torch.argmax(prediction)}")


完整代码已上传至GitHub:项目链接
通过本实现,你已掌握PyTorch图像分类的核心流程,可扩展应用于更复杂的计算机视觉任务!

http://www.dtcms.com/a/477433.html

相关文章:

  • PostgreSQL 测试磁盘性能
  • 北京网站开发科技企业网站
  • 干货|腾讯 Linux C/C++ 后端开发岗面试
  • 【深度学习新浪潮】如何入门分布式大模型推理?
  • 基于单片机的螺旋藻生长大棚PH智能控制设计
  • 分布式专题——42 MQ常见问题梳理
  • mapbox基础,使用矢量切片服务(pbf)加载symbol符号图层
  • Linux中setup_arch和setup_memory相关函数的实现
  • 智能合约在分布式密钥管理系统中的应用
  • Spark大数据分析与实战笔记(第六章 Kafka分布式发布订阅消息系统-01)
  • 做网络竞拍的网站需要什么厦门网站设计哪家公司好
  • React Native:从react的解构看编程众多语言中的解构
  • C++ 手写 List 容器实战:从双向链表原理到完整功能落地,附源码与测试验证
  • 化工课设代做网站网络宣传网站建设价格
  • 【第1篇】2025年羊城工匠杯nl2sql比赛介绍
  • 2025年ASP.NETMVC面试题库全解析
  • 机器学习:支持向量机
  • C 标准库 - `<locale.h>`
  • YOLO系列——Ubuntu20.04下通过conda虚拟环境安装Labelme
  • 流量安全优化:基于 Sentinel 实现网站流量控制和熔断
  • Ansible 自动化部署K8S1.34.1
  • 1. 使用VSCode开发uni-app环境搭建
  • Docker监控:cAdvisor+Prometheus+Grafana实战指南
  • Redis-持久化之AOF
  • Python Redis 教程
  • R语言绘制热图
  • GPU微架构
  • Vue-- Axios 交互(二)
  • 中煤浙江基础建设有限公司网站曹妃甸网站建设
  • phpcms做汽车网站wordpress如何关注博客