深度学习:基于自定义 ResNet 的手写数字识别实践(MNIST 数据集)
目录
一、任务背景与模型选择
二、核心原理:残差块与轻量级 ResNet 设计
1. 残差块设计(ResBlock)
2. 整体网络结构
三、完整代码实现与逐段解析
1. 环境依赖
2. 完整代码与逐段解析
四、训练过程与结果分析
五、总结
一、任务背景与模型选择
手写数字识别是计算机视觉领域的 “入门标杆任务”,其核心是从灰度图像中提取特征并分类(0-9 共 10 类)。常用数据集 MNIST 包含 60000 张训练集和 10000 张测试集,每张图片为 28×28 的灰度图,虽数据规模小、特征简单,但传统 CNN 在深层训练时易出现 “梯度消失” 或 “退化问题”。
本次实践未采用预训练的 ResNet-18(原模型为 3 通道 RGB 输入,适配 MNIST 需额外处理),而是自定义轻量级 ResNet:
- 针对 MNIST 1 通道灰度图设计输入层,避免通道转换冗余;
- 简化残差块结构,减少计算量(适配 CPU/GPU 轻量化训练);
- 保留 ResNet 核心的 “跳跃连接”,解决深层网络训练问题。
该方案的优势在于:无需迁移学习适配,从零训练即可快速收敛,同时让初学者直观理解残差网络的核心逻辑。
二、核心原理:残差块与轻量级 ResNet 设计
ResNet 的核心是残差块(Residual Block),通过 “跳跃连接(Skip Connection)” 让梯度直接回传,避免梯度消失。本次自定义的 ResNet 模型针对 MNIST 数据特点做了 3 点优化:
1. 残差块设计(ResBlock)
传统 ResNet 残差块多为 “3×3 卷积→BN→ReLU” 的组合,本次简化为 “5×3 卷积组合”,在保证特征提取能力的同时减少参数:
- 输入特征图与输出特征图通道数一致(通过
channels_in
控制),确保 “跳跃连接” 时可直接元素相加; - 先通过 5×5 卷积扩大感受野(捕捉数字轮廓),再通过 3×3 卷积细化特征,最后与原始输入相加并 ReLU 激活。
残差块前向传播公式:
其中x
为原始输入(跳跃连接路径),为残差路径。
2. 整体网络结构
自定义 ResNet 针对 28×28 灰度图设计,共包含 “特征提取层→残差块→分类层” 三部分,具体结构如下:
三、完整代码实现与逐段解析
本次实践基于 PyTorch 框架,代码包含 “数据加载→模型定义→训练测试→结果输出” 全流程,可直接复制运行(自动下载 MNIST 数据集)。
1. 环境依赖
确保安装以下库(Python 3.8+,PyTorch 1.10+):
pip install torch torchvision matplotlib
2. 完整代码与逐段解析
# -------------------------- 1. 导入必要库 --------------------------
import torch
from torch import nn # 神经网络核心模块
from torch.utils.data import DataLoader # 批量加载数据
from torchvision import datasets # 加载MNIST数据集
from torchvision.transforms import ToTensor # 图像转为Tensor
from matplotlib import pyplot as plt # 可选:可视化数据(本文暂未用)# -------------------------- 2. 加载并预处理MNIST数据集 --------------------------
# 加载训练集:root为数据保存路径,train=True表示训练集,download=True自动下载
train_data = datasets.MNIST(root='data', # 数据保存在./data目录下(不存在则自动创建)train=True,download=True,transform=ToTensor(), # 转为Tensor:维度H×W×C→C×H×W,值归一化到0-1
)# 加载测试集:train=False表示测试集
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)# 创建数据加载器:按批次加载数据,训练集打乱(shuffle=True)提升泛化能力
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False) # 测试集无需打乱# 查看数据形状:验证输入格式是否正确
for x, y in test_dataloader:print(f"输入图像形状 [批次大小, 通道数, 高度, 宽度]: {x.shape}") # 输出:torch.Size([64, 1, 28, 28])print(f"标签形状: {y.shape},标签数据类型: {y.dtype}") # 输出:torch.Size([64]) torch.int64break # 仅查看第一个批次# -------------------------- 3. 配置训练设备(CPU/GPU自动适配) --------------------------
device = ("cuda" # NVIDIA GPUif torch.cuda.is_available()else "mps" # 苹果M系列芯片GPUif torch.backends.mps.is_available()else "cpu" # 无GPU则用CPU
)
print(f"使用训练设备: {device}")# -------------------------- 4. 定义残差块(ResBlock)与完整ResNet模型 --------------------------
# 残差块:ResNet的核心组件,实现跳跃连接
class ResBlock(nn.Module):def __init__(self, channels_in):super().__init__() # 继承nn.Module的初始化方法# 残差路径:两次卷积(5×5→3×3)self.conv1 = nn.Conv2d(channels_in, 32, 5, padding=2) # 5×5卷积,输出32通道,padding=2保证尺寸不变self.conv2 = nn.Conv2d(32, channels_in, 3, padding=1) # 3×3卷积,输出通道数与输入一致(适配跳跃连接)self.relu = nn.ReLU() # 激活函数def forward(self, x):# 残差路径计算out = self.conv1(x)out = self.conv2(out)# 跳跃连接:残差路径结果 + 原始输入,再激活return self.relu(out + x)# 完整ResNet模型:针对MNIST设计的轻量级版本
class ResNet(nn.Module):def __init__(self):super().__init__()self.relu = nn.ReLU() # 激活函数(复用)# 初始卷积层:1→64通道,5×5卷积(捕捉大尺度边缘)self.conv1 = nn.Conv2d(1, 64, 5, 1, 2) # 输入1通道(灰度图),输出64通道,padding=2保证尺寸不变# 二次卷积层:64→128通道,3×3卷积(细化局部特征)self.conv2 = nn.Conv2d(64, 128, 3, 1, 1) # padding=1保证尺寸不变self.maxpool = nn.MaxPool2d(2) # 2×2最大池化(降维减参)# 残差块:分别对应64通道和128通道的特征图self.resblock1 = ResBlock(channels_in=64)self.resblock2 = ResBlock(channels_in=128)# 全连接层:输入为展平后的特征数(128×7×7=6272),输出10类self.full_c = nn.Linear(6272, 10)def forward(self, x):size = x.shape[0] # 获取批次大小(用于后续展平)# 第一层:卷积→池化→激活x = self.maxpool(self.conv1(x)) # Conv1→MaxPool:28×28→14×14x = self.relu(x)x = self.resblock1(x) # 残差块1:处理64通道特征图# 第二层:卷积→池化→激活x = self.maxpool(self.conv2(x)) # Conv2→MaxPool:14×14→7×7x = self.relu(x)x = self.resblock2(x) # 残差块2:处理128通道特征图# 展平特征图:从[batch, 128, 7, 7]→[batch, 6272]x = x.view(size, -1) # -1表示自动计算剩余维度# 全连接层分类x = self.full_c(x)return x# 初始化模型并转移到目标设备
model = ResNet().to(device)
print("\n自定义残差神经网络结构:")
print(model) # 打印模型结构,验证是否正确# -------------------------- 5. 定义训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""训练函数:单轮训练,更新模型参数"""model.train() # 开启训练模式(启用Dropout/BN更新等)batch_size_num = 0 # 批次计数器,用于打印日志for x, y in dataloader:# 将数据转移到训练设备(CPU/GPU)x, y = x.to(device), y.to(device)# 前向传播:计算模型预测值pred = model(x)# 计算损失(多分类任务用CrossEntropyLoss)loss = loss_fn(pred, y)# 反向传播:更新参数optimizer.zero_grad() # 梯度清零(避免累积)loss.backward() # 计算梯度optimizer.step() # 根据梯度更新参数# 记录损失值,每100个批次打印一次日志loss_value = loss.item() # 从Tensor中提取损失值(避免计算图占用内存)batch_size_num += 1if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [batch: {batch_size_num}]")def test(dataloader, model, loss_fn):"""测试函数:评估模型在测试集上的性能(无参数更新)"""size = len(dataloader.dataset) # 测试集总样本数num_batches = len(dataloader) # 测试集总批次model.eval() # 开启评估模式(冻结BN/Dropout)test_loss, correct = 0, 0 # 累计测试损失和正确预测数# 关闭梯度计算(测试阶段无需反向传播,节省内存和时间)with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)# 累计损失(按批次累加)test_loss += loss_fn(pred, y).item()# 计算正确预测数:pred.argmax(1)取每行最大值索引(预测类别),与标签比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches # 平均损失 = 总损失 / 批次数量correct /= size # 准确率 = 正确数 / 总样本数print(f"Test result: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")# -------------------------- 6. 配置训练超参数并启动训练 --------------------------
# 损失函数:多分类任务用CrossEntropyLoss(内置Softmax,无需手动添加)
loss_fn = nn.CrossEntropyLoss()
# 优化器:Adam优化器,学习率0.0001(小学习率避免震荡)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# 训练轮次:50轮(MNIST数据简单,50轮足够收敛)
epochs = 50# 启动训练循环:每轮训练后测试
for t in range(epochs):print(f"Epoch {t + 1}\n--------------")train(train_dataloader, model, loss_fn, optimizer) # 单轮训练test(test_dataloader, model, loss_fn) # 单轮测试
print("Training Done!") # 训练结束
四、训练过程与结果分析
五、总结
本次基于自定义 ResNet 的 MNIST 手写数字识别实践,以 “简化、适配、高效” 为核心,既验证了残差网络在简单任务中的有效性,也为初学者提供了 “从模型设计到训练落地” 的完整实战路径。实践表明:残差网络的价值不仅在于 “深层”,更在于 “通过跳跃连接解决训练难题”;针对任务特点的轻量化设计,往往比直接套用复杂预训练模型更具性价比。后续可基于此框架,扩展至更复杂的图像分类任务(如 Fashion-MNIST、CIFAR-10),进一步深化对残差网络的理解与应用。