BatchNorm2d详细原理介绍
为什么需要BatchNorm?
深层网络中,每层的输入分布会随着训练而变化,导致训练不稳定,需要更小的学习率。
BatchNorm作用
- 保持每层输入的分布稳定;
- 允许更大学习率,训练快速收敛;
- 有一定正则化效果,减少对dropout的依赖;
- 缓解梯度消失/爆炸
BatchNorm工作原理
如nn.BatchNorm2d(3):
input_tensor = torch.randn(4, 3, 32, 100) # batch=4个样本,3个通道,32x100特征图
batchnorm = nn.BatchNorm2d(3) # 参数3对应输入通道数
output = batchnorm(input_tensor)
具体执行操作如下:
1. 按通道独立计算统计量
# 对于每个通道,在整个batch和空间维度上计算:
for c in range(3): # 提取第c个通道的所有数据 [4, 32, 100] -> 4*32*100=12800个值channel_data = input_tensor[:, c, :, :]# 计算该通道的均值和方差mean_c = channel_data.m