当前位置: 首页 > news >正文

深度学习:基于自定义 ResNet 的手写数字识别实践(MNIST 数据集)

目录

一、任务背景与模型选择

二、核心原理:残差块与轻量级 ResNet 设计

1. 残差块设计(ResBlock)

2. 整体网络结构

三、完整代码实现与逐段解析

1. 环境依赖

2. 完整代码与逐段解析

四、训练过程与结果分析

五、总结


一、任务背景与模型选择

手写数字识别是计算机视觉领域的 “入门标杆任务”,其核心是从灰度图像中提取特征并分类(0-9 共 10 类)。常用数据集 MNIST 包含 60000 张训练集和 10000 张测试集,每张图片为 28×28 的灰度图,虽数据规模小、特征简单,但传统 CNN 在深层训练时易出现 “梯度消失” 或 “退化问题”。

本次实践未采用预训练的 ResNet-18(原模型为 3 通道 RGB 输入,适配 MNIST 需额外处理),而是自定义轻量级 ResNet

  • 针对 MNIST 1 通道灰度图设计输入层,避免通道转换冗余;
  • 简化残差块结构,减少计算量(适配 CPU/GPU 轻量化训练);
  • 保留 ResNet 核心的 “跳跃连接”,解决深层网络训练问题。

该方案的优势在于:无需迁移学习适配,从零训练即可快速收敛,同时让初学者直观理解残差网络的核心逻辑。

二、核心原理:残差块与轻量级 ResNet 设计

ResNet 的核心是残差块(Residual Block),通过 “跳跃连接(Skip Connection)” 让梯度直接回传,避免梯度消失。本次自定义的 ResNet 模型针对 MNIST 数据特点做了 3 点优化:

1. 残差块设计(ResBlock)

传统 ResNet 残差块多为 “3×3 卷积→BN→ReLU” 的组合,本次简化为 “5×3 卷积组合”,在保证特征提取能力的同时减少参数:

  • 输入特征图与输出特征图通道数一致(通过channels_in控制),确保 “跳跃连接” 时可直接元素相加;
  • 先通过 5×5 卷积扩大感受野(捕捉数字轮廓),再通过 3×3 卷积细化特征,最后与原始输入相加并 ReLU 激活。

残差块前向传播公式:

其中x为原始输入(跳跃连接路径),为残差路径。

2. 整体网络结构

自定义 ResNet 针对 28×28 灰度图设计,共包含 “特征提取层→残差块→分类层” 三部分,具体结构如下:

三、完整代码实现与逐段解析

本次实践基于 PyTorch 框架,代码包含 “数据加载→模型定义→训练测试→结果输出” 全流程,可直接复制运行(自动下载 MNIST 数据集)。

1. 环境依赖

确保安装以下库(Python 3.8+,PyTorch 1.10+):

pip install torch torchvision matplotlib

2. 完整代码与逐段解析

# -------------------------- 1. 导入必要库 --------------------------
import torch
from torch import nn  # 神经网络核心模块
from torch.utils.data import DataLoader  # 批量加载数据
from torchvision import datasets  # 加载MNIST数据集
from torchvision.transforms import ToTensor  # 图像转为Tensor
from matplotlib import pyplot as plt  # 可选:可视化数据(本文暂未用)# -------------------------- 2. 加载并预处理MNIST数据集 --------------------------
# 加载训练集:root为数据保存路径,train=True表示训练集,download=True自动下载
train_data = datasets.MNIST(root='data',  # 数据保存在./data目录下(不存在则自动创建)train=True,download=True,transform=ToTensor(),  # 转为Tensor:维度H×W×C→C×H×W,值归一化到0-1
)# 加载测试集:train=False表示测试集
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)# 创建数据加载器:按批次加载数据,训练集打乱(shuffle=True)提升泛化能力
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # 测试集无需打乱# 查看数据形状:验证输入格式是否正确
for x, y in test_dataloader:print(f"输入图像形状 [批次大小, 通道数, 高度, 宽度]: {x.shape}")  # 输出:torch.Size([64, 1, 28, 28])print(f"标签形状: {y.shape},标签数据类型: {y.dtype}")  # 输出:torch.Size([64]) torch.int64break  # 仅查看第一个批次# -------------------------- 3. 配置训练设备(CPU/GPU自动适配) --------------------------
device = ("cuda"  # NVIDIA GPUif torch.cuda.is_available()else "mps"  # 苹果M系列芯片GPUif torch.backends.mps.is_available()else "cpu"  # 无GPU则用CPU
)
print(f"使用训练设备: {device}")# -------------------------- 4. 定义残差块(ResBlock)与完整ResNet模型 --------------------------
# 残差块:ResNet的核心组件,实现跳跃连接
class ResBlock(nn.Module):def __init__(self, channels_in):super().__init__()  # 继承nn.Module的初始化方法# 残差路径:两次卷积(5×5→3×3)self.conv1 = nn.Conv2d(channels_in, 32, 5, padding=2)  # 5×5卷积,输出32通道,padding=2保证尺寸不变self.conv2 = nn.Conv2d(32, channels_in, 3, padding=1)  # 3×3卷积,输出通道数与输入一致(适配跳跃连接)self.relu = nn.ReLU()  # 激活函数def forward(self, x):# 残差路径计算out = self.conv1(x)out = self.conv2(out)# 跳跃连接:残差路径结果 + 原始输入,再激活return self.relu(out + x)# 完整ResNet模型:针对MNIST设计的轻量级版本
class ResNet(nn.Module):def __init__(self):super().__init__()self.relu = nn.ReLU()  # 激活函数(复用)# 初始卷积层:1→64通道,5×5卷积(捕捉大尺度边缘)self.conv1 = nn.Conv2d(1, 64, 5, 1, 2)  # 输入1通道(灰度图),输出64通道,padding=2保证尺寸不变# 二次卷积层:64→128通道,3×3卷积(细化局部特征)self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)  # padding=1保证尺寸不变self.maxpool = nn.MaxPool2d(2)  # 2×2最大池化(降维减参)# 残差块:分别对应64通道和128通道的特征图self.resblock1 = ResBlock(channels_in=64)self.resblock2 = ResBlock(channels_in=128)# 全连接层:输入为展平后的特征数(128×7×7=6272),输出10类self.full_c = nn.Linear(6272, 10)def forward(self, x):size = x.shape[0]  # 获取批次大小(用于后续展平)# 第一层:卷积→池化→激活x = self.maxpool(self.conv1(x))  # Conv1→MaxPool:28×28→14×14x = self.relu(x)x = self.resblock1(x)  # 残差块1:处理64通道特征图# 第二层:卷积→池化→激活x = self.maxpool(self.conv2(x))  # Conv2→MaxPool:14×14→7×7x = self.relu(x)x = self.resblock2(x)  # 残差块2:处理128通道特征图# 展平特征图:从[batch, 128, 7, 7]→[batch, 6272]x = x.view(size, -1)  # -1表示自动计算剩余维度# 全连接层分类x = self.full_c(x)return x# 初始化模型并转移到目标设备
model = ResNet().to(device)
print("\n自定义残差神经网络结构:")
print(model)  # 打印模型结构,验证是否正确# -------------------------- 5. 定义训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""训练函数:单轮训练,更新模型参数"""model.train()  # 开启训练模式(启用Dropout/BN更新等)batch_size_num = 0  # 批次计数器,用于打印日志for x, y in dataloader:# 将数据转移到训练设备(CPU/GPU)x, y = x.to(device), y.to(device)# 前向传播:计算模型预测值pred = model(x)# 计算损失(多分类任务用CrossEntropyLoss)loss = loss_fn(pred, y)# 反向传播:更新参数optimizer.zero_grad()  # 梯度清零(避免累积)loss.backward()  # 计算梯度optimizer.step()  # 根据梯度更新参数# 记录损失值,每100个批次打印一次日志loss_value = loss.item()  # 从Tensor中提取损失值(避免计算图占用内存)batch_size_num += 1if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [batch: {batch_size_num}]")def test(dataloader, model, loss_fn):"""测试函数:评估模型在测试集上的性能(无参数更新)"""size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)  # 测试集总批次model.eval()  # 开启评估模式(冻结BN/Dropout)test_loss, correct = 0, 0  # 累计测试损失和正确预测数# 关闭梯度计算(测试阶段无需反向传播,节省内存和时间)with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)# 累计损失(按批次累加)test_loss += loss_fn(pred, y).item()# 计算正确预测数:pred.argmax(1)取每行最大值索引(预测类别),与标签比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches  # 平均损失 = 总损失 / 批次数量correct /= size  # 准确率 = 正确数 / 总样本数print(f"Test result: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")# -------------------------- 6. 配置训练超参数并启动训练 --------------------------
# 损失函数:多分类任务用CrossEntropyLoss(内置Softmax,无需手动添加)
loss_fn = nn.CrossEntropyLoss()
# 优化器:Adam优化器,学习率0.0001(小学习率避免震荡)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# 训练轮次:50轮(MNIST数据简单,50轮足够收敛)
epochs = 50# 启动训练循环:每轮训练后测试
for t in range(epochs):print(f"Epoch {t + 1}\n--------------")train(train_dataloader, model, loss_fn, optimizer)  # 单轮训练test(test_dataloader, model, loss_fn)  # 单轮测试
print("Training Done!")  # 训练结束

四、训练过程与结果分析

五、总结

本次基于自定义 ResNet 的 MNIST 手写数字识别实践,以 “简化、适配、高效” 为核心,既验证了残差网络在简单任务中的有效性,也为初学者提供了 “从模型设计到训练落地” 的完整实战路径。实践表明:残差网络的价值不仅在于 “深层”,更在于 “通过跳跃连接解决训练难题”;针对任务特点的轻量化设计,往往比直接套用复杂预训练模型更具性价比。后续可基于此框架,扩展至更复杂的图像分类任务(如 Fashion-MNIST、CIFAR-10),进一步深化对残差网络的理解与应用。


文章转载自:

http://NJuw5qsY.ymwny.cn
http://kE4nl9pD.ymwny.cn
http://AwVbFmVq.ymwny.cn
http://n7wgEHAN.ymwny.cn
http://1n0ZfAR8.ymwny.cn
http://3shWn5b3.ymwny.cn
http://5NhVK1Sm.ymwny.cn
http://mu0c8qs5.ymwny.cn
http://sxqNtsUC.ymwny.cn
http://UtvluDKO.ymwny.cn
http://KnIXrie3.ymwny.cn
http://OPDTHBZo.ymwny.cn
http://W3pedUqv.ymwny.cn
http://yk990PpJ.ymwny.cn
http://AgrFMSG4.ymwny.cn
http://NFT7HypG.ymwny.cn
http://g25eZIaf.ymwny.cn
http://zQCuCYMM.ymwny.cn
http://VLs7hIGE.ymwny.cn
http://LeKSYLzI.ymwny.cn
http://uSN9rWNq.ymwny.cn
http://pEfh6CtK.ymwny.cn
http://TFRF36Fy.ymwny.cn
http://42INN6br.ymwny.cn
http://KJ5z3t1t.ymwny.cn
http://Xu4mJksW.ymwny.cn
http://RQQ6oXmK.ymwny.cn
http://71tVeitj.ymwny.cn
http://9YLnBHG9.ymwny.cn
http://vP3cOWWt.ymwny.cn
http://www.dtcms.com/a/366931.html

相关文章:

  • Day35 网络协议与数据封装
  • Vue 3 学习路线指南
  • C语言基础:内存管理
  • 大模型应用开发框架 LangChain
  • Deeplizard深度学习课程(六)—— 结合Tensorboard进行结果分析
  • 小程序:12亿用户的入口,企业数字化的先锋军
  • 【C++题解】关联容器
  • 15,FreeRTOS计数型信号量操作
  • PMP新考纲练习题10道【附答案解析】
  • 开源技术助力企业腾飞,九识智能迈入‘数据驱动’新纪元
  • Docker(①安装)
  • [Windows] PDF工具箱 PDF24 Creator 11.28.0
  • 阿里云轻量应用服务器部署-WooCommerce
  • Java全栈开发面试实战:从基础到高并发的深度解析
  • 并非银弹,而是利器:对软件开发工具的深度探讨与理性思考
  • 使用 Sentry 为 PHP 和 Web 移动小程序提供多平台错误监控
  • 文心iRAG - 百度推出的检索增强的文生图技术,支持生成超真实图片
  • node的模块查找策略
  • HarmonyOS应用开发之界面列表不刷新问题Bug排查记:从现象到解决完整记录
  • 如何架设游戏服务器
  • 如何配置安全的 SFTP 服务器?
  • 【连载 1/9】大模型基础入门学习60页大模型应用:(一)绪论【附全文阅读】
  • Vue基础知识-脚手架开发-初始化目录解析
  • Java面试-HashMap原理
  • 开关电源——只需这三个阶段,从电源小白到维修大神
  • Pydantic模型验证测试:你的API数据真的安全吗?
  • Linux高手才知道的C++高性能I/O秘诀:Vector I/O与DMA深度解析
  • DRMOS电源
  • 经典资金安全案例分享:支付系统开发的血泪教训
  • 手机秒变全栈IDE:Claude Code UI的深度体验