当前位置: 首页 > news >正文

残差网络实战:基于MNIST数据集的手写数字识别

残差网络实战:基于MNIST数据集的手写数字识别

在深度学习的广阔领域中,卷积神经网络(CNN)一直是处理图像任务的主力军。随着研究的深入,网络层数的增加虽然理论上能提升模型的表达能力,但却面临梯度消失、梯度爆炸以及网络退化等问题。残差网络(ResNet)的出现,成功地解决了这些难题,为深度学习的发展开辟了新的道路。本文将通过在MNIST数据集上进行手写数字识别的实战,带大家深入了解残差网络的原理与应用。

一、MNIST数据集介绍

MNIST数据集是机器学习领域中非常经典的图像数据集,它包含了70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28×28像素,并且已经进行了居中处理,大大减少了预处理的工作量,同时也加快了模型的运行速度。在本文的实战中,MNIST数据集将作为我们训练和测试残差网络的“战场”。

二、残差网络原理

传统的神经网络在增加层数时,容易出现网络退化现象,即随着网络层数的增加,模型在训练集和测试集上的性能反而下降。残差网络通过引入残差块(Residual Block)巧妙地解决了这一问题。

残差块的核心思想是让网络学习输入与输出之间的残差,而不是直接学习复杂的映射关系。在一个残差块中,输入数据会经过一系列的卷积、激活等操作,得到一个输出,同时输入数据会直接通过一个快捷连接(Shortcut Connection)与输出相加。这样,网络学习的目标就变成了输出与输入之间的差异,使得训练过程更加容易。

数学上,假设残差块的输入为 x x x,期望的输出映射为 H ( x ) H(x) H(x),通过残差块学习到的函数为 F ( x ) F(x) F(x),那么残差块的输出可以表示为 y = F ( x ) + x y = F(x) + x y=F(x)+x。当 F ( x ) = 0 F(x) = 0 F(x)=0时,残差块的输出就等于输入,这保证了即使增加网络层数,模型的性能也不会下降,反而有机会通过学习残差来提升性能。

三、代码实现与解析

1. 数据加载与预处理

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下载训练数据集(包含训练图片+标签)
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 下载测试数据集(包含训练图片+标签)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# 创建数据DataLoader(数据加载器)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

上述代码利用torchvision库中的datasets.MNIST类下载MNIST数据集,并使用ToTensor()将图像数据转换为PyTorch能够处理的张量格式。然后,通过DataLoader将数据集划分为大小为64的批次,这样做可以减少内存的使用,提高训练速度。

2. 设备配置

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这段代码用于判断当前设备是否支持GPU(CUDA或苹果M系列芯片的MPS),如果支持则使用GPU进行计算,否则使用CPU,充分利用硬件资源加速模型训练。

3. 残差块与网络定义

# 模块搭建
class ResBlock(nn.Module):def __init__(self, channels_in):super().__init__()self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)def forward(self, x):out = self.conv1(x)out = self.conv2(out)return F.relu(out + x)# 网络搭建
class ResNet(nn.Module):def __init__(self):super().__init__()self.conv1 = torch.nn.Conv2d(1, 20, 5)self.conv2 = torch.nn.Conv2d(20, 15, 3)self.maxpool = torch.nn.MaxPool2d(2)self.resblock1 = ResBlock(channels_in=20)self.resblock2 = ResBlock(channels_in=15)self.full_c = torch.nn.Linear(375, 10)def forward(self, x):size = x.shape[0]x = F.relu(self.maxpool(self.conv1(x)))x = self.resblock1(x)x = F.relu(self.maxpool(self.conv2(x)))x = self.resblock2(x)x = x.view(size, -1)x = self.full_c(x)return xmodel = ResNet().to(device)

在上述代码中,首先定义了ResBlock类,实现了残差块的结构,其中包含两个卷积层,并通过快捷连接将输入与卷积层的输出相加,再经过ReLU激活函数得到最终输出。接着,ResNet类构建了完整的残差网络,包含普通卷积层、最大池化层、残差块以及全连接层,将输入图像逐步提取特征并分类为10个数字类别。

4. 训练与测试函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss = loss.item()print(f"loss: {loss:>7f}  [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")

train函数用于模型的训练过程,在每个批次中,将数据传入模型得到预测结果,通过交叉熵损失函数计算损失,进行反向传播更新模型参数,并打印每一批次的损失值。test函数则用于在测试集上评估模型性能,关闭梯度计算以节省内存,计算测试集上的平均损失和准确率。

5. 模型训练与评估

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

最后,定义交叉熵损失函数和Adam优化器,设置训练轮数为10,通过循环调用train函数进行模型训练,训练完成后调用test函数在测试集上评估模型性能。

四、总结与展望

通过在MNIST数据集上的实战,我们成功地实现了基于残差网络的手写数字识别。残差网络凭借其独特的结构设计,有效地解决了深度神经网络中的退化问题,使得我们能够构建更深、更强大的模型。在未来的研究和应用中,残差网络的思想可以应用到更多复杂的图像任务,如图像分割、目标检测等。同时,结合其他先进的技术,如注意力机制、生成对抗网络等,有望进一步提升模型的性能,为深度学习在计算机视觉领域的发展带来更多的可能性。

希望本文能帮助大家更好地理解残差网络,并在实际项目中灵活运用,开启深度学习图像识别的新征程!

相关文章:

  • 主机漏洞扫描:如何保障网络安全及扫描原理与类型介绍?
  • JVM 内存结构全解析
  • 【NLP】32. Transformers (HuggingFace Pipelines 实战)
  • 形式化数学——Lean求值表达式
  • Winform(11.案例讲解1)
  • 探寻适用工具:AI+3D 平台与工具的关键能力及选型考量 (AI+3D 产品经理笔记 S2E03)
  • 动态指令参数:根据组件状态调整指令行为
  • MVC、MVP、MVVM三大架构区别
  • 全球化电商平台AWS云架构设计
  • APK 图标提取软件!一键获取应用宝藏图标
  • TS 类类型
  • 关于 dex2oat 以及 vdex、cdex、dex 格式转换
  • Sui 上线两周年,掀起增长「海啸」
  • 一、Hadoop历史发展与优劣势
  • 项目成本管理_挣得进度ES
  • osquery在网络安全入侵场景中的应用实战(二)
  • 【AND-OR-~OR锁存器设计】2022-8-31
  • 深度学习中学习率调整:提升食物图像分类模型性能的关键实践
  • 山东大学项目实训-创新实训-法律文书专家系统-项目报告(三)
  • Linux常用命令31——groupmod更改群组属性
  • 印巴战火LIVE|巴基斯坦多地遭印度导弹袭击,巴总理称“有权作出适当回应”
  • 南方地区强降雨或致部分河流发生超警洪水,水利部部署防范
  • 白俄罗斯政府代表团将访问朝鲜
  • 张求会谈陈寅恪的生前身后事
  • 新华每日电讯“关爱青年成长”三连评:青春应有多样的精彩
  • 上千游客深夜滞留张家界大喊退票?景区:已采取措施限制人流量