当前位置: 首页 > news >正文

用 PyTorch 搭建 CNN 实现 MNIST 手写数字识别

在图像识别领域,卷积神经网络(CNN) 凭借其对空间特征的高效提取能力,成为手写数字识别、人脸识别等任务的首选模型。而 MNIST(手写数字数据集)作为入门级数据集,几乎是每个深度学习学习者的 “第一个项目”。

本文将带大家从零开始,用 PyTorch 搭建一个 CNN 模型完成 MNIST 手写数字识别任务,不仅会贴出完整代码,还会逐行解析核心逻辑,帮你搞懂 “每个参数为什么这么设”“每一层的作用是什么”,即使是刚接触 PyTorch 的新手也能轻松跟上。

一、前置知识与环境准备

在开始前,我们需要先明确两个核心背景,以及搭建好运行环境:

1. 核心背景速览

  • MNIST 数据集:包含 70000 张 28×28 像素的灰度手写数字图片(0-9),其中 60000 张为训练集,10000 张为测试集,每张图片对应一个 “数字类别” 标签(0-9)。
  • CNN 为什么适合?:相比全连接神经网络,CNN 通过 “卷积层提取局部特征(边缘、纹理)+ 池化层下采样”,能大幅减少参数数量、避免过拟合,同时更好地保留图像的空间结构信息。

2. 环境准备

需要安装 PyTorch 和 TorchVision(PyTorch 官方的计算机视觉库,内置 MNIST 数据集):

# pip安装命令(根据系统自动匹配版本,若需指定CUDA版本可参考PyTorch官网)
pip install torch torchvision

验证环境是否安装成功:

import torch
print(torch.__version__)  # 输出PyTorch版本,如2.0.1
print(torch.cuda.is_available())  # 输出True表示支持GPU加速(需NVIDIA显卡)

二、完整代码先行(可直接运行)

先贴出完整可运行的代码,后面会逐段拆解解析:

注意:nn.Sequential()是将网络层组合在一起,内部不能写函数

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 1. 加载MNIST数据集
train_data = datasets.MNIST(root='data',          # 数据保存路径train=True,           # 加载训练集download=True,        # 若路径下无数据则自动下载transform=ToTensor()  # 将图像转为Tensor(0-1归一化+维度调整:(H,W,C)→(C,H,W))
)
test_data = datasets.MNIST(root='data',train=False,          # 加载测试集download=True,transform=ToTensor()
)# 2. 数据加载器(分批处理数据)
train_loader = DataLoader(train_data, batch_size=64)  # 每批64个样本
test_loader = DataLoader(test_data, batch_size=64)# 3. 设备配置(优先GPU,其次CPU)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')  # 打印当前使用的设备# 4. 定义CNN模型
class CNN(nn.Module):def __init__(self):super().__init__()# 卷积块1:输入(1,28,28) → 输出(8,14,14)self.conv1 = nn.Sequential(# 卷积层:1个输入通道→8个输出通道,卷积核5×5,步长1,填充2nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2),nn.ReLU(),  # 激活函数(引入非线性)nn.MaxPool2d(kernel_size=2)  # 池化层:2×2下采样,尺寸减半)# 卷积块2:输入(8,14,14) → 输出(32,7,7)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2),  # 8→16通道,其他参数同上nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2),  # 16→32通道nn.ReLU(),nn.MaxPool2d(2)  # 下采样后尺寸14→7)# 卷积块3:输入(32,7,7) → 输出(64,7,7)(无池化,保留尺寸)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 32→64通道nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),  # 64→64通道(加深特征提取)nn.ReLU())# 全连接层:输入(64×7×7) → 输出10(对应10个数字类别)self.out = nn.Linear(64 * 7 * 7, 10)# 前向传播(定义数据在模型中的流动路径)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # 展平:(batch_size, 64,7,7) → (batch_size, 64×7×7)output = self.out(x)return output# 5. 初始化模型并移至指定设备
model = CNN().to(device)
print(model)  # 打印模型结构,验证是否正确# 6. 定义训练函数
def train(dataloader, model, loss_fn, optimizer):model.train()  # 启用训练模式(如BatchNorm、Dropout会生效)batch_count = 1  # 计数批次,用于打印日志for X, y in dataloader:# 将数据移至指定设备(GPU/CPU)X, y = X.to(device), y.to(device)# 前向传播:计算模型预测值pred = model(X)# 计算损失(多分类任务用CrossEntropyLoss)loss = loss_fn(pred, y)# 反向传播:更新模型参数optimizer.zero_grad()  # 清空上一轮梯度(避免累积)loss.backward()        # 计算梯度(反向传播)optimizer.step()       # 根据梯度更新参数(优化器执行)# 每100个批次打印一次损失(监控训练进度)if batch_count % 100 == 0:loss_value = loss.item()  # 取出损失值(脱离计算图)print(f'Batch: {batch_count:>4} | Loss: {loss_value:>6.4f}')batch_count += 1# 7. 定义测试函数
def test(dataloader, model, loss_fn):model.eval()  # 启用评估模式(关闭BatchNorm、Dropout)total_samples = len(dataloader.dataset)  # 测试集总样本数correct = 0  # 正确预测的样本数total_loss = 0  # 总损失# 禁用梯度计算(测试阶段无需更新参数,节省内存)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累积损失和正确数total_loss += loss_fn(pred, y).item()# pred.argmax(1):取每行最大概率的索引(即预测类别),与y比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率avg_loss = total_loss / len(dataloader)  # len(dataloader) = 总批次accuracy = (correct / total_samples) * 100  # 准确率(百分比)print(f'\nTest Result | Accuracy: {accuracy:>5.2f}% | Avg Loss: {avg_loss:>6.4f}\n')# 8. 配置训练参数并执行
loss_fn = nn.CrossEntropyLoss()  # 多分类交叉熵损失(内置Softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器,学习率0.001
epochs = 10  # 训练轮次(整个训练集遍历10次)# 循环训练+测试
for epoch in range(epochs):print(f'=================== Epoch {epoch + 1}/{epochs} ===================')train(train_loader, model, loss_fn, optimizer)  # 训练一轮test(test_loader, model, loss_fn)  # 测试一轮print("Training Finished!")

三、核心代码逐段解析

上面的代码看似长,但逻辑很清晰,我们按 “数据→模型→训练→测试” 的流程拆解核心部分。

1. 数据加载与预处理

MNIST 数据集的加载全靠torchvision.datasets.MNIST,无需手动下载和解析,非常方便。关键参数解析:

  • root='data':数据会保存在当前目录的data文件夹下(自动创建);
  • train=True/FalseTrue加载 6 万张训练集,False加载 1 万张测试集;
  • transform=ToTensor():这是核心预处理步骤,作用有两个:
    1. 将图像从 “PIL 格式(0-255 像素值)” 转为 “Tensor 格式(0-1 归一化值)”,避免大数值导致梯度爆炸;
    2. 调整维度:从图像默认的(高度H, 宽度W, 通道C)转为 PyTorch 要求的(通道C, 高度H, 宽度W)(MNIST 是灰度图,C=1)。

然后用DataLoader将数据集分批:

  • batch_size=64:每次训练取 64 个样本计算梯度(batch_size 越大,训练越稳定,但内存占用越高);
  • DataLoader会自动打乱训练集(默认shuffle=True),避免模型学习到 “样本顺序” 的无关特征。

2. 设备配置:GPU 加速有多重要?

代码中这行是 “硬件适配” 的关键:

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
  • cuda:NVIDIA 显卡的 GPU 加速(训练 10 轮可能只需 1-2 分钟);
  • mps:苹果芯片(M1/M2)的 GPU 加速;
  • cpu:默认选项(训练 10 轮可能需要 10-20 分钟,速度慢很多)。

后续通过model.to(device)X.to(device),将模型和数据都移到指定设备上,确保计算在同一设备进行(否则会报错)。

3. CNN 模型搭建(核心中的核心)

我们定义的CNN类继承自nn.Module(PyTorch 所有模型的基类),核心是__init__(定义层)和forward(定义数据流动)。

先看模型结构总览

输入(1,28,28) → 卷积块1 → 输出(8,14,14) → 卷积块2 → 输出(32,7,7) → 卷积块3 → 输出(64,7,7) → 展平 → 全连接层 → 输出(10)
(1)卷积层参数解析

conv1的第一个卷积层为例:

nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2)
  • in_channels=1:输入通道数(MNIST 是灰度图,所以 1);
  • out_channels=8:输出通道数 = 卷积核数量(8 个卷积核,提取 8 种不同特征);
  • kernel_size=5:卷积核大小(5×5 的窗口,比 3×3 能提取更复杂的局部特征);
  • stride=1:卷积核每次滑动 1 个像素(步长越小,特征保留越完整);
  • padding=2:填充(在图像边缘补 2 个像素),目的是让卷积后图像尺寸不变:
    👉 尺寸计算公式:输出尺寸 = (输入尺寸 - 卷积核尺寸 + 2×padding) / stride + 1
    👉 代入:(28 - 5 + 2×2)/1 + 1 = 28,所以卷积后还是 28×28。
(2)激活函数 ReLU

每个卷积层后都加nn.ReLU(),作用是引入非线性

  • 没有激活函数的话,多个卷积层叠加还是线性变换,无法拟合复杂数据;
  • ReLU 的公式:ReLU(x) = max(0, x),计算简单、梯度不易消失,是目前最常用的激活函数。
(3)池化层 MaxPool2d

nn.MaxPool2d(kernel_size=2)是 2×2 最大池化,作用是下采样

  • 尺寸减半:28×28→14×14,14×14→7×7,大幅减少后续计算量;
  • 保留关键特征:取 2×2 窗口的最大值,相当于 “强化局部最显著的特征”,提高模型鲁棒性。
(4)展平与全连接层

卷积块 3 输出的是(batch_size, 64, 7, 7)的张量(batch_size 是每批样本数),需要用x.view(x.size(0), -1)展平为(batch_size, 64×7×7)的一维向量,才能输入全连接层:

  • x.size(0):获取 batch_size(确保展平后每一行对应一个样本);
  • -1:让 PyTorch 自动计算剩余维度(64×7×7=3136);
  • 全连接层nn.Linear(3136, 10):将 3136 维特征映射到 10 维(对应 0-9 的 10 个类别)。

4. 训练函数:模型如何 “学习”?

训练的核心是 “前向传播算损失→反向传播求梯度→优化器更新参数” 的循环:

  1. model.train():启用训练模式(比如如果模型有 BatchNorm,会计算当前批次的均值和方差);
  2. 前向传播:pred = model(X),用当前模型参数计算预测值;
  3. 计算损失:loss = loss_fn(pred, y),用CrossEntropyLoss(多分类任务专用,内置了 Softmax,无需手动在模型输出加 Softmax);
  4. 反向传播:
    • optimizer.zero_grad():清空上一轮的梯度(如果不清空,梯度会累积,导致参数更新错误);
    • loss.backward():自动计算所有可训练参数的梯度(PyTorch 的自动微分机制);
    • optimizer.step():用计算出的梯度更新参数(Adam 优化器会自适应调整学习率,比 SGD 收敛更快)。

5. 测试函数:模型学得怎么样?

测试阶段不需要更新参数,核心是计算 “准确率” 和 “平均损失”:

  1. model.eval():启用评估模式(关闭 BatchNorm 的批次统计更新、关闭 Dropout);
  2. with torch.no_grad():禁用梯度计算(节省内存,加速测试);
  3. 准确率计算:pred.argmax(1) == y,比较预测类别和真实类别,求和后除以总样本数。

四、预期结果与优化方向

1. 预期训练结果

在 GPU 上训练 10 轮后,通常能达到:

  • 测试准确率:98.5% 以上(甚至 99%);
  • 测试平均损失:0.04 以下。

训练过程中,损失会逐渐下降,准确率会逐渐上升(如果出现损失不下降或准确率波动,可能是学习率太大或 batch_size 太小)。

2. 模型优化方向

如果想进一步提升性能,可以尝试这些改进:

  1. 增加 Dropout 层:在卷积层或全连接层后加nn.Dropout(0.2),随机 “关闭” 20% 的神经元,防止过拟合;
  2. 使用学习率调度:比如torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5),每 5 轮将学习率减半,后期精细调整;
  3. 加深网络:增加卷积块数量(比如再加一个 conv4),或增加每个卷积层的输出通道数;
  4. 数据增强:用torchvision.transforms添加旋转、平移、缩放等操作,比如:
    transform = transforms.Compose([transforms.RandomRotation(5),  # 随机旋转±5度transforms.ToTensor()
    ])

        数据增强能让模型看到更多 “变种” 样本,提升泛化能力。

五、总结

本文用 PyTorch 实现了一个基础的 CNN 模型,完成了 MNIST 手写数字识别任务,核心收获包括:

  1. 掌握了 PyTorch 加载数据集、搭建 CNN 模型的基本流程;
  2. 理解了卷积层、池化层、激活函数的作用和参数意义;
  3. 熟悉了 “训练 - 测试” 的循环逻辑,以及 GPU 加速的配置方法。

MNIST 是入门任务,但 CNN 的核心思想(特征提取 + 下采样)可以迁移到更复杂的图像任务(如 CIFAR-10、ImageNet)。建议大家动手修改代码,比如调整卷积核大小、学习率、网络层数,观察结果变化,这样才能真正理解每个参数的影响~

http://www.dtcms.com/a/354526.html

相关文章:

  • 如何开发线下陪玩儿小程序
  • 【图像处理基石】DCT在图像处理中的应用及实现
  • natapp 内网穿透
  • 【iOS】Masnory自动布局的简单学习
  • 图算法详解:最短路径、拓扑排序与关键路径
  • 使用 httpsok 工具全面排查网站安全配置
  • Nginx + Certbot配置 HTTPS / SSL 证书(简化版已测试)
  • Android稳定性问题的常见原因是什么
  • JSP程序设计之JSP指令
  • react+vite+ts 组件模板
  • CVPR2025丨VL2Lite:如何将巨型VLM的“知识”精炼后灌入轻量网络?这项蒸馏技术实现了任务专用的极致压缩
  • 传统星型拓扑结构的5G,WiFi无线通信网络与替代拓扑结构自组网
  • BGP路由协议(一):基本概念
  • UE的SimpleUDPTCPSocket插件使用
  • 百度地图+vue+flask+爬虫 推荐算法旅游大数据可视化系统Echarts mysql数据库 带沙箱支付+图像识别技术
  • 【数字黑洞2178】2022-10-28
  • Linux学习-TCP并发服务器构建(epoll)
  • 【C++】C++11的右值引用和移动语义
  • Unity游戏打包——iOS打包基础、上传
  • 使用Docker部署ZLMediaKit流媒体服务器实现gb/t28181协议的设备
  • Day30 多线程编程 同步与互斥 任务队列调度
  • ArcGIS学习-12 实战-综合案例
  • Unity游戏打包——iOS打包pod的重装和使用
  • Flutter:ios打包ipa,证书申请,Xcode打包,完整流程
  • Intern-S1-mini模型结构
  • SpringBoot系列之实现高效批量写入数据
  • 专项智能练习(图形图像基础)
  • 文本处理与模型对比:BERT, Prompt, Regex, TF-IDF
  • 高精度惯性导航IMU价格与供应商
  • [sys-BlueChi] docs | BluechiCtl命令行工具