【深度学习系列】ResNet网络原理与mnist手写数字识别实现
一、ResNet简介
ResNet(Residual Network,残差网络)是由何恺明大神等人于2015年提出的一种深度学习网络结构。
通过引入残差模块(Residual Block)和跳跃连接(Skip Connection),解决了传统深度神经网络中梯度消失和梯度爆炸、网络退化(随着网络深度的继续增加,模型准确率毫无征兆的出现大幅度的降低)的问题,使得网络可以训练到非常深的层次(如1000层以上),并显著提升了模型性能。
二、ResNet网络原理
1. 残差学习框架
ResNet的核心思想是残差学习。假设网络的输入为x,期望输出为H(x),ResNet将H(x)分解为F(x)+x,其中F(x)表示残差,x表示输入的恒等映射。这样,网络只需要学习输入与输出之间的残差F(x),而不是直接学习H(x),大大简化了优化过程。
H(x)=F(x)+x,x是输入,H(x)是期望的输出,F(x)是输入和输出之间的差,即残差。举个栗子,输入一张高糊的照片,期望网络对照片进行处理,输出一张高清的结果照片。那么利用resnet的网络模型,如果只有一层的话,就是在原照片x的基础上进行优化清晰度。
2. 残差块结构
残差块是ResNet的基本构建单元。一个典型的残差块包含两个卷积层,每个卷积层后面接一个批量归一化层(Batch Normalization)和ReLU激活函数。残差块的输出是输入x与卷积层输出F(x)的和,即H(x)=F(x)+x。如果输入和输出的特征图维度不一致,可以通过1×1卷积进行升维或降维。
类型 | 结构 | 适用网络 |
---|---|---|
BasicBlock | [Conv3x3]-[BN]-[ReLU]×2 | ResNet-18/34 |
Bottleneck | [Conv1x1]-[Conv3x3]-[Conv1x1] | ResNet-50+ |
3. 跳跃连接
跳跃连接是指将前面若干层的输出直接连接到后面的层。这种连接方式使得梯度可以直接流过整个网络,缓解了梯度消失问题。跳跃连接不仅允许信息在不同层次之间快速传递,还增强了网络的特征提取能力。
4. 网络结构
ResNet的整体结构由多个残差块堆叠而成。常见的ResNet版本包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152,其中数字表示网络的层数。以ResNet-18为例,其网络结构如下:
-
输入层:一个7×7的卷积层,步幅为2,输出通道数为64,后面接一个3×3的最大池化层,步幅为2。
-
残差块组:包含4个残差块模块,每个模块包含若干个残差块。第一个模块的通道数与输入通道数一致,后续模块的通道数逐个翻倍,同时高和宽减半。
-
输出层:全局平均池化层,将特征图压缩为1×1,然后通过一个全连接层输出分类结果。
三、ResNet的实现
核心代码如下:
def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):identity = self.shortcut(x)out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += identityreturn F.relu(out)
相加之前需要判断x和F(x)的通道数和宽高是否一致,否则需要使用1×1卷积进行调整之后再相加。
完整代码见我的github仓库:https://github.com/lovesuger/ResNet.git