当前位置: 首页 > news >正文

卷积神经网络实现mnist手写数字集识别案例

手写数字识别是计算机视觉领域的“Hello World”,也是深度学习入门的经典案例。它通过训练模型识别0-9的手写数字图像(如MNIST数据集),帮助我们快速掌握神经网络的核心流程。本文将以PyTorch框架为基础,带你从数据加载、模型构建到训练评估,完整实现一个手写数字识别系统。

二、数据加载与预处理:认识MNIST数据集

1. MNIST数据集简介

MNIST是手写数字的标准数据集,包含:

  • 训练集:60,000张28x28的灰度图(0-9数字)
  • 测试集:10,000张同尺寸图片
  • 每张图片已归一化(像素值0-1),标签为0-9的整数

2. 代码实现:下载与加载数据

使用torchvision.datasets可直接下载MNIST,transforms.ToTensor()将图片转为PyTorch张量(通道优先格式:[1,28,28],1为灰度通道数)。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下载训练集(60,000张)
train_data = datasets.MNIST(root="data",       # 数据存储路径train=True,        # 标记为训练集download=True,     # 自动下载(首次运行时)transform=ToTensor()  # 转为张量(shape: [1,28,28])
)# 下载测试集(10,000张)
test_data = datasets.MNIST(root="data",train=False,       # 标记为测试集download=True,transform=ToTensor()
)

3. 数据封装:DataLoader批量加载

DataLoader将数据集打包为可迭代的批量数据,支持随机打乱(训练集)、多线程加载等。

device = "cuda" if torch.cuda.is_available() else "cpu"  # 自动选择GPU/CPU
batch_size = 64  # 每批64张图片(可根据显存调整)# 训练集DataLoader(打乱顺序)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 测试集DataLoader(不打乱顺序)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

三、模型构建:设计卷积神经网络(CNN)

1. 为什么选择CNN?

手写数字识别需要捕捉图像的局部特征(如笔画边缘、拐点),而CNN的卷积层通过滑动窗口提取局部模式,池化层降低计算量,全连接层完成分类,非常适合处理图像任务。

2. 模型结构详解(附代码注释)

以下是我们定义的CNN模型,包含3个卷积块和1个全连接输出层:

class CNN(nn.Module):def __init__(self):super().__init__()  # 继承PyTorch模块基类# 卷积块1:输入1通道(灰度图)→ 输出8通道特征图self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,    # 输入通道数(灰度图)out_channels=8,   # 输出8个特征图(8个卷积核)kernel_size=5,    # 卷积核尺寸5x5(覆盖局部区域)stride=1,         # 滑动步长1(不跳跃)padding=2         # 边缘填充2圈0(保持输出尺寸不变)),nn.ReLU(),  # 非线性激活(引入复杂模式)nn.MaxPool2d(kernel_size=2)  # 最大池化(2x2窗口,尺寸减半))# 卷积块2:特征抽象(8→16→32通道)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2),  # 8→16通道,5x5卷积,填充2(尺寸不变)nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2), # 16→32通道,5x5卷积,填充2(尺寸不变)nn.ReLU(),nn.MaxPool2d(kernel_size=2)  # 尺寸减半(14→7))# 卷积块3:特征精炼(32→256通道,保留空间信息)self.conv3 = nn.Sequential(nn.Conv2d(32, 256, 5, 1, 2),  # 32→256通道,5x5卷积,填充2(尺寸不变)nn.ReLU())# 全连接输出层:256*7*7维特征→10类概率self.out = nn.Linear(256 * 7 * 7, 10)  # 10对应0-9数字类别def forward(self, x):"""前向传播:定义数据流动路径"""x = self.conv1(x)  # 输入:[64,1,28,28] → 输出:[64,8,14,14](池化后尺寸减半)x = self.conv2(x)  # 输入:[64,8,14,14] → 输出:[64,32,7,7](两次卷积+池化)x = self.conv3(x)  # 输入:[64,32,7,7] → 输出:[64,256,7,7](仅卷积)x = x.view(x.size(0), -1)  # 展平:[64,256,7,7] → [64,256*7*7](全连接需要一维输入)output = self.out(x)       # 输出:[64,10](每个样本对应10类的得分)return output

3. 关键参数计算(以输入28x28为例)

  • conv1后:卷积核5x5,填充2,输出尺寸(28-5+2*2)/1 +1=28;池化后尺寸28/2=14 → 输出[64,8,14,14]
  • conv2后:两次卷积保持14x14,池化后14/2=7 → 输出[64,32,7,7]
  • conv3后:卷积保持7x7 → 输出[64,256,7,7]
  • 展平后256*7*7=12544维向量 → 全连接到10类

四、训练配置:损失函数与优化器

1. 损失函数:交叉熵损失(CrossEntropyLoss)

手写数字识别是多分类任务,交叉熵损失函数直接衡量模型输出概率与真实标签的差异。PyTorch的nn.CrossEntropyLoss已集成Softmax操作(无需手动添加)。

2. 优化器:随机梯度下降(SGD)

优化器负责根据损失值更新模型参数。这里选择SGD(学习率lr=0.1),简单且对小数据集友好(也可尝试Adam等更复杂的优化器)。

model = CNN().to(device)  # 模型加载到GPU/CPU
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # SGD优化器

五、训练循环:让模型“学习”特征

1. 训练逻辑概述

训练过程的核心是“前向传播→计算损失→反向传播→更新参数”,重复直到模型收敛。具体步骤:

  1. 模型设为训练模式(model.train());
  2. 遍历训练数据,按批输入模型;
  3. 计算预测值与真实标签的损失;
  4. 反向传播计算梯度(loss.backward());
  5. 优化器更新参数(optimizer.step());
  6. 清空梯度(optimizer.zero_grad())避免累积。

2. 代码实现:训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启训练模式(影响Dropout/BatchNorm等层)total_loss = 0  # 记录总损失for batch_idx, (x, y) in enumerate(dataloader):x, y = x.to(device), y.to(device)  # 数据加载到GPU/CPU# 1. 前向传播:模型预测pred = model(x)# 2. 计算损失:预测值 vs 真实标签loss = loss_fn(pred, y)total_loss += loss.item()  # 累加批次损失# 3. 反向传播:计算梯度optimizer.zero_grad()  # 清空历史梯度loss.backward()        # 反向传播计算当前梯度# 4. 更新参数:根据梯度调整模型权重optimizer.step()# 每100个批次打印一次损失(监控训练进度)if (batch_idx + 1) % 100 == 0:print(f"批次 {batch_idx+1}/{len(dataloader)}, 当前损失: {loss.item():.4f}")avg_loss = total_loss / len(dataloader)print(f"训练完成,平均损失: {avg_loss:.4f}")

六、测试评估:验证模型泛化能力

1. 测试逻辑概述

测试阶段需关闭模型的随机操作(如Dropout),用测试集评估模型的泛化能力。核心指标是准确率(正确预测的样本比例)。

2. 代码实现:测试函数

def test(dataloader, model):model.eval()  # 开启评估模式(关闭Dropout等随机层)correct = 0   # 记录正确预测数total = 0     # 记录总样本数with torch.no_grad():  # 关闭梯度计算(节省内存)for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)  # 模型预测# 统计正确数:pred.argmax(1)取预测概率最大的类别correct += (pred.argmax(1) == y).sum().item()total += y.size(0)  # 累加批次样本数accuracy = correct / totalprint(f"测试准确率: {accuracy * 100:.2f}%")return accuracy

七、完整训练与结果

1. 运行训练循环

我们训练10个epoch(遍历整个训练集10次):

# 训练10轮
for epoch in range(10):print(f"
====={epoch+1} 轮训练 =====")train(train_dataloader, model, loss_fn, optimizer)# 测试最终效果
print("
===== 最终测试 =====")
test_acc = test(test_dataloader, model)

2. 典型输出结果

假设训练10轮后,测试准确率可能达到98.5%+(具体取决于超参数和硬件):

===== 第 1 轮训练 =====
批次 100/938, 当前损失: 0.2145
...
训练完成,平均损失: 0.1234===== 第 10 轮训练 =====
批次 100/938, 当前损失: 0.0321
...
训练完成,平均损失: 0.0189===== 最终测试 =====
测试准确率: 98.76%

八、改进方向:让模型更强大

当前模型已能较好识别手写数字,但仍有优化空间:

1. 调整超参数

  • 学习率:若损失下降缓慢,降低lr(如0.01);若波动大,增大lr
  • 批量大小:增大batch_size(如128)可加速训练(需更大显存)。
  • 训练轮次:增加epoch(如20轮),但需防止过拟合(训练损失持续下降,测试损失上升)。

2. 添加正则化

  • Batch Normalization:在卷积层后添加nn.BatchNorm2d(out_channels),加速收敛并稳定训练。
    self.conv1 = nn.Sequential(nn.Conv2d(1,8,5,1,2),nn.BatchNorm2d(8),  # 新增nn.ReLU(),nn.MaxPool2d(2)
    )
    
  • Dropout:在全连接层前添加nn.Dropout(p=0.5),随机断开神经元,防止过拟合。
    self.out = nn.Sequential(nn.Dropout(0.5),  # 新增nn.Linear(256*7*7, 10)
    )
    

3. 使用更深的网络

当前模型仅3个卷积块,对于复杂任务(如ImageNet),可使用ResNet等残差网络,通过跳跃连接(Skip Connection)解决深层网络的梯度消失问题。

九、总结

通过本文,你已完成从数据加载到模型训练的全流程,掌握了:

  • 数据预处理:使用torchvision加载标准数据集,DataLoader批量管理数据;
  • 模型构建:设计CNN的核心组件(卷积层、激活函数、池化层);
  • 训练与评估:理解损失函数、优化器的作用,掌握训练循环和测试逻辑。

手写数字识别是深度学习的起点,你可以尝试修改模型结构(如增加卷积层)、更换数据集(如Fashion-MNIST)或调整超参数,进一步探索深度学习的魅力!

动手建议:运行代码时,尝试将device改为cpu(无GPU时),观察训练速度变化;或修改kernel_size(如3x3),对比模型性能差异。

http://www.dtcms.com/a/354986.html

相关文章:

  • Apollo-PETRv1演示DEMO操作指南
  • 【Qt】QCryptographicHash 设置密钥(Key)
  • Deeplizard 深度学习课程(四)—— 模型构建
  • jwt原理及Java中实现
  • 海盗王64位dx9客户端修改篇之二
  • 学习Java29天(tcp多发多收)但是无解决客户端启动多个问题
  • ProfiNet 转 Ethernet/IP 柔性产线构建方案:网关技术保护新能源企业现有设备投资
  • LeetCode Hot 100 第7天
  • 第三十天:世界杯队伍团结力问题
  • EF Core 编译模型 / 模型裁剪:冷启动与查询优化
  • QT之双缓冲 (QMutex/QWaitCondition)——读写分离
  • 企业如何管理跨多个系统的主数据?
  • MaxCompute MaxFrame | 分布式Python计算服务MaxFrame(完整操作版)
  • 【Lua】题目小练12
  • 如何实现HTML动态爱心表白效果?
  • 多版本并发控制MVCC
  • 黑马点评|项目日记(day02)
  • C#和Lua相互访问
  • 基于金庸武侠小说人物关系设计的完整 SQL 语句,包括数据库创建、表结构定义和示例数据插入
  • Docker 详解+示例
  • map底层的数据结构是什么,为什么不用AVL树
  • 机器学习回顾(一)
  • 陪诊小程序系统开发:搭建医患之间的温暖桥梁
  • Scrapy 基础介绍
  • 安全运维——系统上线前安全检测:漏洞扫描、系统基线与应用基线的全面解析
  • lwIP MQTT 心跳 Bug 分析与修复
  • 边缘计算(Edge Computing)+ AI:未来智能世界的核心引擎
  • HarmonyOS 组件与页面生命周期:全面解析与实践
  • Paimon——官网阅读:Flink 引擎
  • 【秋招笔试】2025.08.27华为秋招研发岗真题