当前位置: 首页 > news >正文

用 PyTorch 从零实现 MNIST 手写数字识别

在深度学习入门阶段,MNIST 手写数字识别是经典的 “Hello World” 项目。它不仅能帮助我们熟悉深度学习的核心流程,还能直观理解神经网络如何从数据中学习规律。本文将基于 PyTorch 框架,从零拆解 MNIST 识别模型的实现过程,涵盖数据准备、模型构建、训练优化到性能评估的全流程,并深入讲解每一步背后的原理。

一、环境准备与核心库导入

在开始前,需确保已安装 PyTorch 环境(建议搭配torchvision用于加载数据集)。首先导入所需库,这是构建深度学习模型的基础。

# 导入PyTorch核心库并验证版本
import torch
print(torch.__version__)  # 打印版本号,确保环境配置正确(本文基于2.0+版本测试)# 导入神经网络与数据处理相关模块
from torch import nn  # 神经网络核心层(如Linear、Flatten)
from torch.utils.data import DataLoader  # 数据批量加载工具
from torchvision import datasets  # 计算机视觉常用数据集(含MNIST)
from torchvision.transforms import ToTensor  # 图像格式转换工具

核心库作用解析

  • torch:PyTorch 的主库,提供张量计算、自动微分等核心功能,是所有操作的基础。
  • torch.nn:封装了神经网络的常用组件(如全连接层、激活函数、损失函数),简化模型定义流程。
  • torch.utils.data.DataLoader:将数据集按批次拆分,支持多线程加载,是高效训练的关键工具。
  • torchvision.datasets:提供经典视觉数据集(如 MNIST、CIFAR),支持自动下载,无需手动处理数据文件。
  • torchvision.transforms.ToTensor:将图像从 PIL 格式(或 numpy 数组)转换为 PyTorch 的Tensor格式,并将像素值从0-255归一化到0-1,符合神经网络的输入要求。

二、数据加载与预处理

深度学习的效果依赖高质量数据,MNIST 数据集是手写数字识别的标准数据集,包含 60000 张训练图像和 10000 张测试图像,每张图像为 28×28 像素的灰度图,标签对应 0-9 的数字类别。

1. 加载 MNIST 数据集

# 加载训练集
training_data = datasets.MNIST(root='data',  # 数据存储路径(本地不存在时自动创建)train=True,   # True表示加载训练集,False表示加载测试集download=True,  # 本地无数据时自动从官网下载(约10MB)transform=ToTensor(),  # 数据转换:PIL→Tensor+归一化
)# 加载测试集
test_data = datasets.MNIST(root='data',train=False,  # 加载测试集(用于评估模型泛化能力)download=True,transform=ToTensor(),
)# 验证数据加载效果:打印训练集样本数量
print(f"训练集样本数:{len(training_data)}")  # 输出:60000
print(f"测试集样本数:{len(test_data)}")      # 输出:10000

2. 创建 DataLoader:批量加载数据

直接遍历原始数据集效率低,DataLoader可将数据按批次(batch_size)拆分,同时支持随机打乱(shuffle=True)和多线程加载,大幅提升训练效率。

# 训练集DataLoader:batch_size=64(每次处理64张图像)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# 测试集DataLoader:无需打乱(评估时只需按序计算)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)# 验证DataLoader输出格式
for X, y in test_dataloader:print(f"输入图像X的形状:[N, C, H, W] = {X.shape}")  # 输出:[64, 1, 28, 28]print(f"标签y的形状:[N] = {y.shape}")              # 输出:[64]print(f"标签y的数据类型:{y.dtype}")                # 输出:torch.int64break
维度含义解析
  • X.shape = [64, 1, 28, 28]N=64(批次大小)、C=1(通道数,灰度图为 1,彩色图为 3)、H=28(图像高度)、W=28(图像宽度)。
  • y.shape = [64]:每个样本对应 1 个标签(0-9 的整数),与批次大小一致。

三、选择计算设备(CPU/GPU)

神经网络训练依赖大量矩阵运算,GPU(尤其是 NVIDIA GPU)能大幅加速计算。PyTorch 支持自动检测并选择最优设备,代码如下:

# 优先级:NVIDIA GPU (cuda) → Apple M系列GPU (mps) → CPU
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"使用计算设备:{device}")
  • cuda:适用于 NVIDIA 显卡(需安装 CUDA Toolkit),训练速度比 CPU 快 10-100 倍。
  • mps:适用于 Apple M1/M2 系列芯片,利用苹果自研 GPU 加速。
  • cpu:无专用 GPU 时使用,训练速度较慢)。

四、定义神经网络模型

本文设计一个简单的三层全连接神经网络,结构为:输入层→隐藏层 1→隐藏层 2→输出层。全连接层(nn.Linear)的核心是矩阵乘法,将前一层的特征映射到后一层。

1.模型定义代码

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28 * 28, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)# 前向传播:定义数据在模型中的流动路径def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.out(x)return x# 创建模型实例并移动到指定设备
model = NeuralNetwork().to(device)
print(model)

2.模型结构输出与解析

运行后会打印模型结构:

NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(hidden1): Linear(in_features=784, out_features=128, bias=True)(hidden2): Linear(in_features=128, out_features=256, bias=True)(out): Linear(in_features=256, out_features=10, bias=True)
)
关键组件解析
  1. nn.Flatten:将输入从[N, C, H, W]展平为[N, C×H×W],消除空间维度,适配全连接层的一维输入要求。
  2. nn.Linear(in_features, out_features):全连接层,本质是矩阵乘法:output = x × weight + bias,其中weight(权重)和bias(偏置)是模型需要学习的参数。
    • 例如隐藏层 1:输入 784 维,输出 128 维,对应权重矩阵形状为[784, 128],偏置向量形状为[128]
  3. torch.sigmoid(x):激活函数,将输入值压缩到0-1区间,为模型引入非线性(若没有激活函数,多层全连接等价于单层,无法学习复杂规律)。
  4. 输出层:输出 10 维向量,对应每个数字类别的 “原始分数”(未归一化的概率),后续通过损失函数自动转换为概率。

3. 定义训练与测试函数

训练函数负责 “教模型学习”,测试函数负责 “检验学习效果”:

def train(dataloader, model, loss_fn, optimizer):model.train()  #告诉模型开始训练,pytorch提供2种房市来切换训练和测试模式#训练 model.train()   测试model.eval()batch_count = 1  # 记录批次号for X, y in dataloader:# 数据与模型同设备(否则无法计算)X, y = X.to(device), y.to(device)# 1. 前向传播:计算预测值pred = model.forwarf(X)  # 等价于model.forward(X)# 2. 计算损失(预测值与真实标签的差距)loss = loss_fn(pred, y)# 3. 反向传播与参数更新(核心!)optimizer.zero_grad()  # 清空上一轮梯度(防止累积错误)loss.backward()        # 反向传播:计算各参数的梯度optimizer.step()       # 根据梯度更新参数(权重/偏置)# 监控训练进度:每100批次打印损失if batch_count % 100 == 0:print(f"损失:{loss.item():>7f} | 批次:{batch_count}")batch_count += 1def test(dataloader, model, loss_fn):model.eval()  # 切换到评估模式(禁用Dropout等)total_size = len(dataloader.dataset)  # 测试集总样本数total_batches = len(dataloader)       # 测试集总批次数test_loss, correct = 0, 0             # 总损失、正确预测数# 禁用梯度计算(测试无需更新参数,节省资源)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)# 累加损失test_loss += loss_fn(pred, y).item()# 统计正确数:取预测分数最高的类别(argmax(1))与真实标签比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失与准确率avg_loss = test_loss / total_batchesaccuracy = (correct / total_size) * 100print(f"\n测试结果:")print(f"准确率:{accuracy:>0.1f}% | 平均损失:{avg_loss:>8f}\n")

4. 配置训练参数并执行

# 1. 损失函数:交叉熵损失(适用于多分类,自动含Softmax)
loss_fn = nn.CrossEntropyLoss()
# 2. 优化器:随机梯度下降(SGD),学习率lr=0.01
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 3. 训练轮次:10轮(每轮遍历一次完整训练集)
epochs = 10# 执行训练与测试
print("开始训练:")
for t in range(epochs):print(f"\n===== 第{t+1}轮训练 =====")train(train_dataloader, model, loss_fn, optimizer)
print("训练结束!")# 最终测试模型性能
test(test_dataloader, model, loss_fn)

原代码预期结果:使用 Sigmoid 激活函数,10 轮训练后准确率约 92%-94%,但训练后期损失下降缓慢 —— 这正是 Sigmoid 的缺陷导致的,我们将在下文分析。

五、梯度消失与梯度爆炸:原理与解决方案

在原代码中,Sigmoid 激活函数导致训练后期损失下降缓慢,本质是梯度消失—— 这是深度学习中阻碍模型训练的核心问题之一。我们从原理、现象、解决方案三方面展开。

1. 梯度消失(Vanishing Gradient)

(1)原理:链式法则的 “乘积衰减”

神经网络的梯度通过链式法则反向传播,即:

其中,H1、H2是隐藏层输出,W1是输入层到隐藏层 1 的权重。

以 Sigmoid 为例 Sigmoid 的导数最大值仅为0.25(当 x=0 时)。若网络有 5 层隐藏层,每层导数取 0.25,则梯度经过 5 层传播后:0.25^5 = 0.00097656 梯度已衰减到原来的 1/1000 以下 —— 这就是梯度消失:浅层网络(靠近输入层)的梯度几乎为 0,参数无法更新,模型停止学习。

2. 梯度爆炸(Exploding Gradient)

(1)原理:链式法则的 “乘积放大”

与梯度消失相反,若各层梯度的乘积大于 1,则梯度会指数级增长:

  • 例如,权重初始化过大(如 W 的均值为 2),激活函数导数为 1,则 5 层后梯度为2^5=32,10 层后为2^10=1024;
  • 梯度值过大,会导致参数更新时 “一步跨度过大”,甚至出现 NaN(数值溢出)。

3. 解决方案

替换激活函数为RELU激活函数

ReLU 的数学定义与导数特性
1. ReLU 的数学公式

ReLU 是一个极其简单的分段函数,定义为: ReLU(x)=max(0,x)

  • 当输入x>=0时,y = x;
  • 当输入x < 0 时,y = 0

导数为:

  • 当输入x>0时,y '= 1;
  • 当输入x < 0 时,y '= 0
(1)x > 0 时,导数恒为1——避免梯度消失

当隐藏层神经元的输入 x > 0 时,ReLU的导数为1。此时,梯度在反向传播过程中:

  • 各层激活函数导数的乘积为 1*1*1*1....*1=1;
  • 梯度不会因为“多层传播”而衰减,浅层参数(如输入层→第1隐藏层的权重)能获得有效的梯度更新。
(2)导数范围被限制在{0, 1}——避免梯度爆炸

梯度爆炸的核心是“梯度乘积大于1,导致指数级增长”,而ReLU的导数只有两个可能值:0或1。无论网络有多少层,各层导数的乘积最大为1,而ReLU的导数只有两个可能值:0或1。无论网络有多少层,各层导数的乘积最大为1。

代码示例:

    def forward(self,x):x = self.flatten(x)x = self.hidden1(x)x = torch.relu(x)x = self.hidden2(x)x = torch.relu(x)x = self.out(x)return x

relu激活函数可能在比较少的网络层中作用不大,但是当网络层数量大的时候,relu函数可以体现出较为重要的作用。

http://www.dtcms.com/a/350430.html

相关文章:

  • 微论-神经网络中记忆的演变
  • volatile关键字:防止寄存器操作被优化
  • Java设计模式-装饰器模式:从“咖啡加料”到Java架构
  • 动态线程池核心解密:从 Nacos 到 Pub/Sub 架构的实现与对比
  • 使用百度统计来统计浏览量
  • 网易算法岗位--面试真题分析
  • 江苏安全员 A 证 “安全生产管理” 核心考点
  • 【笔记】Roop 之 NSFW 检测屏蔽测试
  • 电池分选机:破解电池性能一致性难题的自动化方案|深圳比斯特
  • 【车载开发系列】ParaSoft集成测试环境配置(五)
  • Seaborn数据可视化实战:Seaborn数据可视化实战入门
  • 我的小灶坑
  • 使用 gemini 来分析 github 项目
  • 【Day 33】Linux-Mysql日志
  • Linux 系统内存不足导致服务崩溃的排查方法
  • 跨站脚本攻击(XSS)分类介绍及解决办法
  • 单北斗变形监测系统应用维护指南
  • 59 C++ 现代C++编程艺术8-智能指针
  • 探索量子计算的新前沿
  • 深度学习之第三课PyTorch( MNIST 手写数字识别神经网络模型)
  • Telematics Control Unit(TCU)的系统化梳理
  • 从零开始学习单片机14
  • Fory序列化与反序列化
  • 以正确方式构建AI Agents:Agentic AI的设计原则
  • 技术速递|使用 AI 应用模板扩展创建一个 .NET AI 应用与自定义数据进行对话
  • 【Hadoop】HDFS 分布式存储系统
  • Nuxt.js@4 中管理 HTML <head> 标签
  • 【二叉树 - LeetCode】236. 二叉树的最近公共祖先
  • TAISAW钛硕|TST嘉硕Differential output Crystal Oscillator - TW0692AAAE40
  • [electron]开发环境驱动识别失败