当前位置: 首页 > news >正文

使用 PyTorch 实现 MNIST 手写数字识别

一、背景说明

本实例通过MNIST 手写数字识别任务,演示如何利用 PyTorch 构建并训练神经网络。

环境要求:PyTorch 1.5+、支持 GPU/CPU;

数据集:MNIST(手写数字数据集,含 0-9 共 10 类,图像尺寸 28×28);

核心流程:数据下载与预处理→源数据可视化→模型构建→损失函数与优化器定义→模型训练→训练结果可视化。

二、数据准备

1. 导入依赖库

python

运行

import numpy as np
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

2. 定义超参数

python

运行

train_batch_size = 64   # 训练批次大小
test_batch_size = 128   # 测试批次大小
learning_rate = 0.01    # 初始学习率
num_epochs = 20         # 训练轮数

3. 数据预处理与加载

通过transforms做 “张量转换 + 标准化”,再用DataLoader生成可迭代的数据加载器。

python

运行

# 组合预处理操作:转张量 + 标准化(均值/标准差为0.5)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])# 下载并加载MNIST训练集、测试集
train_dataset = MNIST('../data/', train=True, transform=transform, download=True)
test_dataset = MNIST('../data/', train=False, transform=transform)# 创建数据加载器(训练集打乱,测试集不打乱)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

三、源数据可视化

matplotlib可视化测试集样本,直观观察数据。

python

运行

import matplotlib.pyplot as plt
%matplotlib inline  # Jupyter中 inline 显示图像# 从测试加载器取一个批次样本
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)# 绘制6个样本及真实标签
fig = plt.figure()
for i in range(6):plt.subplot(2, 3, i+1)plt.tight_layout()plt.imshow(example_data[i][0], cmap='gray', interpolation='none')  # 灰度显示plt.title(f'Ground Truth: {example_targets[i]}')  # 标注真实标签plt.xticks([]); plt.yticks([])  # 隐藏坐标轴

四、构建神经网络模型

定义继承nn.Module的模型类Net,通过nn.Sequential组合层,实现前向传播。

python

运行

class Net(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Net, self).__init__()self.flatten = nn.Flatten()  # 展平层:28×28 → 784# 隐藏层1:线性变换 + 批归一化self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))# 隐藏层2:线性变换 + 批归一化self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))# 输出层:线性变换self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = self.flatten(x)  # 展平输入x = F.relu(self.layer1(x))  # 隐藏层1 + ReLU激活x = F.relu(self.layer2(x))  # 隐藏层2 + ReLU激活x = F.softmax(self.out(x), dim=1)  # 输出层 + Softmax(按行归一化)return x

五、模型实例化与优化器定义

python

运行

# 自动选择设备(优先GPU,否则CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 实例化模型(输入784=28×28,隐藏层300/100,输出10类)
model = Net(28 * 28, 300, 100, 10)
model.to(device)  # 模型移到指定设备# 损失函数(交叉熵)+ 优化器(SGD+动量)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

六、训练模型

训练阶段(更新参数)和验证阶段(评估性能),同时记录损失与准确率。

python

运行

losses = []          # 训练损失
acces = []           # 训练准确率
eval_losses = []     # 验证损失
eval_acces = []      # 验证准确率
writer = SummaryWriter(log_dir='logs', comment='train-loss')  # TensorBoard记录for epoch in range(num_epochs):# ---------- 训练阶段 ----------train_loss, train_acc = 0, 0model.train()  # 开启训练模式(批归一化、Dropout生效)# 动态调整学习率:每5个epoch,学习率×0.9if epoch % 5 == 0:optimizer.param_groups[0]['lr'] *= 0.9print(f"学习率: {optimizer.param_groups[0]['lr']:.6f}")for img, label in train_loader:img, label = img.to(device), label.to(device)  # 数据移到设备# 正向传播out = model(img)loss = criterion(out, label)# 反向传播 + 参数更新optimizer.zero_grad()  # 清零梯度loss.backward()        # 损失回传optimizer.step()       # 更新参数# 累加训练损失train_loss += loss.item()writer.add_scalar('Train', train_loss / len(train_loader), epoch)  # 记录到TensorBoard# 累加训练准确率_, pred = out.max(1)  # 取概率最大的索引为预测类别num_correct = (pred == label).sum().item()train_acc += num_correct / img.shape[0]# 保存当前epoch的训练损失/准确率(取平均)losses.append(train_loss / len(train_loader))acces.append(train_acc / len(train_loader))# ---------- 验证阶段 ----------eval_loss, eval_acc = 0, 0model.eval()  # 开启评估模式(批归一化、Dropout不生效)with torch.no_grad():  # 关闭梯度计算(加速+省内存)for img, label in test_loader:img, label = img.to(device), label.to(device)img = img.view(img.size(0), -1)  # 展平图像(匹配模型输入)out = model(img)loss = criterion(out, label)# 累加验证损失eval_loss += loss.item()# 累加验证准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()eval_acc += num_correct / img.shape[0]# 保存当前epoch的验证损失/准确率(取平均)eval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))# 打印训练&验证结果print(f'Epoch: {epoch}, Train Loss: {train_loss/len(train_loader):.4f}, 'f'Train Acc: {train_acc/len(train_loader):.4f}, 'f'Test Loss: {eval_loss/len(test_loader):.4f}, 'f'Test Acc: {eval_acc/len(test_loader):.4f}')

七、训练结果可视化

训练损失曲线为例,观察模型收敛情况:

python

运行

import matplotlib.pyplot as pltplt.title('Train Loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

曲线显示:训练损失随 epoch 增加逐渐下降,模型持续学习并收敛。

关键知识点总结

  1. 数据处理transforms做预处理,DataLoader实现批量加载与打乱。
  2. 模型构建:继承nn.Module,用nn.Sequential组合层,自定义forward实现前向传播。
  3. 训练逻辑:区分train()/eval()模式,反向传播更新参数,同时记录损失与准确率。
  4. 可视化工具matplotlib画曲线,SummaryWriter结合 TensorBoard 做更丰富的可视化。
  5. 训练技巧:动态调整学习率、SGD + 动量加速收敛、批归一化(BatchNorm)稳定训练。
http://www.dtcms.com/a/477442.html

相关文章:

  • ComfyUI安装和启动攻略1
  • h5移动端开发民治网站优化培训
  • uniapp 微信小程序蓝牙接收中文乱码
  • 多制式基站综合测试线的架构与验证实践 (1)
  • Ceph 分布式存储学习笔记(四):文件系统存储管理
  • ceph设置标志位
  • 系统升级丨让VR全景制作更全面、更简单
  • PyTorch 实现 MNIST 手写数字识别全流程
  • PyTorch实现MNIST手写数字识别:从数据到模型全解析
  • PostgreSQL 测试磁盘性能
  • 北京网站开发科技企业网站
  • 干货|腾讯 Linux C/C++ 后端开发岗面试
  • 【深度学习新浪潮】如何入门分布式大模型推理?
  • 基于单片机的螺旋藻生长大棚PH智能控制设计
  • 分布式专题——42 MQ常见问题梳理
  • mapbox基础,使用矢量切片服务(pbf)加载symbol符号图层
  • Linux中setup_arch和setup_memory相关函数的实现
  • 智能合约在分布式密钥管理系统中的应用
  • Spark大数据分析与实战笔记(第六章 Kafka分布式发布订阅消息系统-01)
  • 做网络竞拍的网站需要什么厦门网站设计哪家公司好
  • React Native:从react的解构看编程众多语言中的解构
  • C++ 手写 List 容器实战:从双向链表原理到完整功能落地,附源码与测试验证
  • 化工课设代做网站网络宣传网站建设价格
  • 【第1篇】2025年羊城工匠杯nl2sql比赛介绍
  • 2025年ASP.NETMVC面试题库全解析
  • 机器学习:支持向量机
  • C 标准库 - `<locale.h>`
  • YOLO系列——Ubuntu20.04下通过conda虚拟环境安装Labelme
  • 流量安全优化:基于 Sentinel 实现网站流量控制和熔断
  • Ansible 自动化部署K8S1.34.1