残差连接的概念与作用
核心概念:
残差连接是一种在深度神经网络(特别是非常深的网络)中使用的技术,它的核心思想是允许信息从较浅的层直接“跳跃”或“短路”到较深的层。它是由何恺明等人在2015年提出的ResNet(残差网络)中首次引入的,并彻底改变了深度学习的格局,使得训练成百上千层的超深网络成为可能。
为什么需要残差连接?
在残差连接出现之前,人们发现一个悖论:简单地堆叠更多的层来构建更深的网络,反而会导致训练误差和测试误差都增大。这被称为退化问题。
造成退化的主要原因之一是梯度消失/爆炸:
- 梯度消失: 在反向传播过程中,梯度需要从网络的输出层逐层传递回输入层。当网络非常深时,梯度在传递过程中会不断被乘以小于1的数(例如Sigmoid或Tanh激活函数的导数),导致梯度值越来越小,最终趋近于零。这使得浅层网络的权重几乎得不到有效的更新,学习停滞。
- 梯度爆炸: 相反,如果梯度值大于1,在反向传播过程中会不断被放大,导致权重更新过大,网络变得不稳定。
- 优化困难: 即使解决了梯度问题,深层网络的优化曲面可能变得更加复杂和非凸,使得优化算法更难找到好的解。
残差连接如何工作?
残差连接通过引入一个“捷径连接”来解决上述问题。让我们对比一下传统的网络层和带有残差连接的层:
-
传统层(Plain Layer):
- 输入:
x
- 输出:
y = F(x)
- 其中
F(x)
代表该层(或连续的几层)要学习的映射函数(通常是卷积、激活、归一化等操作的组合)。
- 输入:
-
带有残差连接的层(Residual Block):
- 输入:
x
- 主干路径输出:
F(x)
(与传统层相同) - 捷径连接: 直接将输入
x
传递过来。 - 最终输出:
H(x) = F(x) + x
- 目标映射
H(x)
被分解为两部分:原始输入x
和一个残差函数F(x) = H(x) - x
。
- 输入:
图解:
想象一个由两层(或更多层)组成的“块”:
输入 x|v
+---------+ +---------+
| Layer1 | --> | Layer2 | --> F(x)
+---------+ +---------+| || v+--------------> (+) --> 输出 H(x) = F(x) + x
箭头 ----->
表示主路径(需要学习的变换 F(x)
)。
箭头 -------------->
表示捷径连接(Identity Mapping,恒等映射),它直接将输入 x
绕过主路径,加到主路径的输出 F(x)
上。
关键优势和工作原理:
-
解决梯度消失/爆炸:
- 在反向传播时,梯度不仅可以通过主路径
F(x)
传播回来,还可以直接通过捷径连接传播回来。 - 即使主路径的梯度变得非常小(接近0),梯度仍然可以通过捷径连接相对无损地传递到更浅的层(因为
dH(x)/dx = d(F(x) + x)/dx = dF(x)/dx + 1 ≈ 1
,只要dF(x)/dx
不太负)。这确保了浅层网络能够获得足够强的梯度信号进行有效的学习更新。 - 同样,梯度爆炸的风险也降低了,因为梯度有两条路径可以分散。
- 在反向传播时,梯度不仅可以通过主路径
-
缓解退化问题:
- 网络不再需要学习完整的映射
H(x)
,而是学习残差F(x) = H(x) - x
。 - 学习“从输入到输出的变化量”通常比学习“整个输出”要容易得多。特别是当最优的
H(x)
非常接近x
(即残差接近0)时,网络只需将F(x)
推向0,这比学习一个恒等映射(H(x) = x
)更容易(因为权重初始化为接近0,F(x)
初始值也接近0)。 - 如果增加的层是冗余的(即最优解就是
H(x) = x
),那么网络可以轻松地将F(x)
学习为0,使得H(x) = x
,这样增加的层就不会降低性能。而在普通网络中,学习H(x) = x
可能并不容易。
- 网络不再需要学习完整的映射
-
促进信息流动:
- 捷径连接为信息提供了一条高速公路,允许原始特征信息无损地传递到更深的层。这有助于保留早期提取的低级特征(如边缘、纹理),并与深层提取的高级特征进行融合,可能提升模型的表示能力。
-
使训练极深网络成为可能:
- 正是由于解决了梯度问题和退化问题,ResNet才能成功训练出152层甚至1000层以上的网络,并在ImageNet等基准上取得了当时最好的性能。
实际实现细节:
- 维度匹配: 当捷径连接的输入
x
和主路径输出F(x)
的维度(通道数、高、宽)不一致时(例如下采样时),需要在捷径连接上做变换(通常是1x1卷积)来匹配维度,然后再相加。公式变为H(x) = F(x) + W_s * x
,其中W_s
是投影矩阵(1x1卷积)。 - 激活函数位置: 通常,激活函数(如ReLU)放在加法操作之后。但在ResNet v2等改进版本中,提出了“预激活”结构(BN->ReLU->Conv),将激活放在卷积之前,理论上具有更好的性能。
总结:
残差连接是一种革命性的神经网络结构设计,它通过引入跨层的恒等捷径连接,将需要学习的映射转化为学习残差函数。这种方法有效地解决了深度神经网络训练中的梯度消失/爆炸问题和退化问题,使得训练成百上千层的超深网络成为现实,极大地推动了深度学习的发展。它已成为现代神经网络架构(如ResNet及其变种、Transformer等)中不可或缺的基础组件。