残差网络实战:基于MNIST数据集的手写数字识别
残差网络实战:基于MNIST数据集的手写数字识别
在深度学习的广阔领域中,卷积神经网络(CNN)一直是处理图像任务的主力军。随着研究的深入,网络层数的增加虽然理论上能提升模型的表达能力,但却面临梯度消失、梯度爆炸以及网络退化等问题。残差网络(ResNet)的出现,成功地解决了这些难题,为深度学习的发展开辟了新的道路。本文将通过在MNIST数据集上进行手写数字识别的实战,带大家深入了解残差网络的原理与应用。
一、MNIST数据集介绍
MNIST数据集是机器学习领域中非常经典的图像数据集,它包含了70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28×28像素,并且已经进行了居中处理,大大减少了预处理的工作量,同时也加快了模型的运行速度。在本文的实战中,MNIST数据集将作为我们训练和测试残差网络的“战场”。
二、残差网络原理
传统的神经网络在增加层数时,容易出现网络退化现象,即随着网络层数的增加,模型在训练集和测试集上的性能反而下降。残差网络通过引入残差块(Residual Block)巧妙地解决了这一问题。
残差块的核心思想是让网络学习输入与输出之间的残差,而不是直接学习复杂的映射关系。在一个残差块中,输入数据会经过一系列的卷积、激活等操作,得到一个输出,同时输入数据会直接通过一个快捷连接(Shortcut Connection)与输出相加。这样,网络学习的目标就变成了输出与输入之间的差异,使得训练过程更加容易。
数学上,假设残差块的输入为 x x x,期望的输出映射为 H ( x ) H(x) H(x),通过残差块学习到的函数为 F ( x ) F(x) F(x),那么残差块的输出可以表示为 y = F ( x ) + x y = F(x) + x y=F(x)+x。当 F ( x ) = 0 F(x) = 0 F(x)=0时,残差块的输出就等于输入,这保证了即使增加网络层数,模型的性能也不会下降,反而有机会通过学习残差来提升性能。
三、代码实现与解析
1. 数据加载与预处理
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下载训练数据集(包含训练图片+标签)
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 下载测试数据集(包含训练图片+标签)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# 创建数据DataLoader(数据加载器)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
上述代码利用torchvision
库中的datasets.MNIST
类下载MNIST数据集,并使用ToTensor()
将图像数据转换为PyTorch能够处理的张量格式。然后,通过DataLoader
将数据集划分为大小为64的批次,这样做可以减少内存的使用,提高训练速度。
2. 设备配置
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
这段代码用于判断当前设备是否支持GPU(CUDA或苹果M系列芯片的MPS),如果支持则使用GPU进行计算,否则使用CPU,充分利用硬件资源加速模型训练。
3. 残差块与网络定义
# 模块搭建
class ResBlock(nn.Module):def __init__(self, channels_in):super().__init__()self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)def forward(self, x):out = self.conv1(x)out = self.conv2(out)return F.relu(out + x)# 网络搭建
class ResNet(nn.Module):def __init__(self):super().__init__()self.conv1 = torch.nn.Conv2d(1, 20, 5)self.conv2 = torch.nn.Conv2d(20, 15, 3)self.maxpool = torch.nn.MaxPool2d(2)self.resblock1 = ResBlock(channels_in=20)self.resblock2 = ResBlock(channels_in=15)self.full_c = torch.nn.Linear(375, 10)def forward(self, x):size = x.shape[0]x = F.relu(self.maxpool(self.conv1(x)))x = self.resblock1(x)x = F.relu(self.maxpool(self.conv2(x)))x = self.resblock2(x)x = x.view(size, -1)x = self.full_c(x)return xmodel = ResNet().to(device)
在上述代码中,首先定义了ResBlock
类,实现了残差块的结构,其中包含两个卷积层,并通过快捷连接将输入与卷积层的输出相加,再经过ReLU激活函数得到最终输出。接着,ResNet
类构建了完整的残差网络,包含普通卷积层、最大池化层、残差块以及全连接层,将输入图像逐步提取特征并分类为10个数字类别。
4. 训练与测试函数
def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss = loss.item()print(f"loss: {loss:>7f} [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")
train
函数用于模型的训练过程,在每个批次中,将数据传入模型得到预测结果,通过交叉熵损失函数计算损失,进行反向传播更新模型参数,并打印每一批次的损失值。test
函数则用于在测试集上评估模型性能,关闭梯度计算以节省内存,计算测试集上的平均损失和准确率。
5. 模型训练与评估
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
最后,定义交叉熵损失函数和Adam优化器,设置训练轮数为10,通过循环调用train
函数进行模型训练,训练完成后调用test
函数在测试集上评估模型性能。
四、总结与展望
通过在MNIST数据集上的实战,我们成功地实现了基于残差网络的手写数字识别。残差网络凭借其独特的结构设计,有效地解决了深度神经网络中的退化问题,使得我们能够构建更深、更强大的模型。在未来的研究和应用中,残差网络的思想可以应用到更多复杂的图像任务,如图像分割、目标检测等。同时,结合其他先进的技术,如注意力机制、生成对抗网络等,有望进一步提升模型的性能,为深度学习在计算机视觉领域的发展带来更多的可能性。
希望本文能帮助大家更好地理解残差网络,并在实际项目中灵活运用,开启深度学习图像识别的新征程!