【nn.GroupNorm】
组归一化(Group Normalization, GN)是一种深度学习中的归一化方法,旨在解决批量归一化(Batch Normalization, BN)在小批量(mini-batch)较小时性能下降的问题。GN 通过将通道(channels)分组进行归一化,不依赖于批量维度,因此对批量大小不敏感。
GN 的数学公式
给定输入张量 x x x,其维度为 ( N , C , H , W ) (N, C, H, W) (N,C,H,W)(分别表示批量大小、通道数、高度、宽度),GN 的计算步骤如下:
-
分组:
将 C C C 个通道分成 G G G 组,每组包含 C G \frac{C}{G} GC个通道(假设 ( C ) 能被 ( G ) 整除)。
设 x n , g x_{n,g} xn,g表示第 n n n 个样本的第 g g g 组特征(维度为 C G × H × W \frac{C}{G} \times H \times W GC×H×W)。 -
计算均值和方差:
对每个样本的每个组分别计算均值和方差:
μ n , g = 1 ∣ G ∣ ∑ c ∈ G ∑ h , w x n , c , h , w , \mu_{n,g} = \frac{1}{|\mathcal{G}|} \sum_{c \in \mathcal{G}} \sum_{h,w} x_{n,c,h,w}, μn,g=∣G∣1c∈G∑h,w∑xn,c,h,w,
σ n , g 2 = 1 ∣ G ∣ ∑ c ∈ G ∑ h , w ( x n , c , h , w − μ n , g ) 2 , \sigma_{n,g}^2 = \frac{1}{|\mathcal{G}|} \sum_{c \in \mathcal{G}} \sum_{h,w} (x_{n,c,h,w} - \mu_{n,g})^2, σn,g2=∣G∣1c∈G∑h,w∑(xn,c,h,w−μn,g)2,
其中 G \mathcal{G} G表示当前组的所有通道索引, ∣ G ∣ = C G × H × W |\mathcal{G}| = \frac{C}{G} \times H \times W ∣G∣=GC×H×W。 -
归一化:
使用计算得到的均值和方差对输入进行归一化:
x ^ n , c , h , w = x n , c , h , w − μ n , g σ n , g 2 + ϵ , \hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^2 + \epsilon}}, x^n,c,h,w=σn,g2+ϵxn,c,h,w−μn,g,
其中 ϵ \epsilon ϵ 是一个很小的常数(如 10 − 5 10^{-5} 10−5),用于数值稳定性。 -
仿射变换(可选):
类似于 BN,GN 也可以引入可学习的缩放(scale)和平移(shift)参数:
y n , c , h , w = γ c x ^ n , c , h , w + β c , y_{n,c,h,w} = \gamma_c \hat{x}_{n,c,h,w} + \beta_c, yn,c,h,w=γcx^n,c,h,w+βc,
其中 γ c \gamma_c γc和 β c \beta_c βc 是可训练的参数,分别对应缩放和平移。
特点
- 不依赖批量大小:GN 对每个样本独立计算统计量,适用于小批量或单样本训练(如目标检测、视频处理等任务)。
- 组数 ( G ) 是超参数:
- 当 ( G=1 ),GN 退化为 Layer Normalization (LN)(对整个通道归一化)。
- 当 ( G=C ),GN 退化为 Instance Normalization (IN)(每个通道单独归一化)。
- 计算开销比 BN 略高:因为需要逐样本计算统计量,但比 LN 和 IN 更灵活。
PyTorch 实现
在 PyTorch 中,GN 可通过 torch.nn.GroupNorm
实现:
import torch.nn as nn# 假设输入通道数 C=32,组数 G=8
gn = nn.GroupNorm(num_groups=8, num_channels=32)
output = gn(input_tensor) # input_tensor: (N, 32, H, W)
适用场景
- 小批量训练(如 batch size < 16)。
- 任务对批量统计量敏感(如检测、分割、生成模型等)。
GN 在 ResNet、Mask R-CNN 等模型中表现良好,尤其在批量较小时优于 BN。
在 PyTorch 中,nn.GroupNorm
的计算过程可以分为以下几个步骤,我们结合代码和数学公式详细说明。
1. 输入张量的形状
假设输入张量 x
的维度为 (N, C, H, W)
,其中:
N
:batch size(样本数量)C
:通道数(channels)H
、W
:特征图的高度和宽度
例如,x.shape = (2, 6, 3, 3)
表示:
N = 2
(2 个样本)C = 6
(6 个通道)H = W = 3
(3×3 的特征图)
2. 分组(Grouping)
nn.GroupNorm
的关键是将通道 C
分成 G
组(G
是超参数)。
- 每组通道数:
C_per_group = C // G
- 要求
C
必须能被G
整除,否则会报错。
示例:
如果 C = 6
,G = 2
,则:
- 每组有
6 // 2 = 3
个通道 - 组 0:通道
[0, 1, 2]
- 组 1:通道
[3, 4, 5]
3. 计算均值和方差(Per-Group Statistics)
对 每个样本 n
和 每个组 g
,计算该组所有通道的均值和方差:
- 计算范围:
(C_per_group, H, W)
的所有值 - 公式:
μ n , g = 1 C-per-group × H × W ∑ c ∈ group g ∑ h , w x n , c , h , w \mu_{n,g} = \frac{1}{\text{C-per-group} \times H \times W} \sum_{c \in \text{group}_g} \sum_{h,w} x_{n,c,h,w} μn,g=C-per-group×H×W1c∈groupg∑h,w∑xn,c,h,w
σ n , g 2 = 1 C-per-group × H × W ∑ c ∈ group g ∑ h , w ( x n , c , h , w − μ n , g ) 2 \sigma_{n,g}^2 = \frac{1}{\text{C-per-group} \times H \times W} \sum_{c \in \text{group}_g} \sum_{h,w} (x_{n,c,h,w} - \mu_{n,g})^2 σn,g2=C-per-group×H×W1c∈groupg∑h,w∑(xn,c,h,w−μn,g)2
示例(x.shape = (2, 6, 3, 3)
,G=2
):
- 对
n=0, g=0
(第 0 个样本的第 0 组,通道[0,1,2]
),计算这3×3×3=27
个值的均值和方差。 - 对
n=0, g=1
(第 0 个样本的第 1 组,通道[3,4,5]
),同样计算均值和方差。 - 对
n=1
的所有组重复上述过程。
4. 归一化(Normalize)
使用计算出的 μ
和 σ²
对每个组进行归一化:
x ^ n , c , h , w = x n , c , h , w − μ n , g σ n , g 2 + ϵ \hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^2 + \epsilon}} x^n,c,h,w=σn,g2+ϵxn,c,h,w−μn,g
其中:
ϵ
是一个极小值(默认1e-5
),防止除以零。g
是通道c
所属的组(g = c // (C // G)
)。
示例:
- 如果
c=4
,G=2
,C=6
,则g = 4 // 3 = 1
(属于第 1 组)。 - 用
μ_{n,1}
和σ_{n,1}^2
归一化通道4
的所有值。
5. 仿射变换(Scale & Shift)
GN 最后会应用可学习的缩放(γ
)和平移(β
)参数:
y n , c , h , w = γ c ⋅ x ^ n , c , h , w + β c y_{n,c,h,w} = \gamma_c \cdot \hat{x}_{n,c,h,w} + \beta_c yn,c,h,w=γc⋅x^n,c,h,w+βc
其中:
γ
和β
是可训练参数,形状为(C,)
(每个通道独立)。- 如果初始化时
affine=False
,则跳过此步(γ=1
,β=0
)。
PyTorch 代码模拟
以下是手动实现 nn.GroupNorm
的代码(简化版):
import torchdef group_norm(x, G, gamma, beta, eps=1e-5):N, C, H, W = x.shapeassert C % G == 0, "C must be divisible by G"C_per_group = C // G# Reshape to (N, G, C_per_group, H, W)x_grouped = x.view(N, G, C_per_group, H, W)# Compute mean and var per group (shape: N, G)mean = x_grouped.mean(dim=(2, 3, 4), keepdim=True)var = x_grouped.var(dim=(2, 3, 4), keepdim=True, unbiased=False)# Normalizex_norm = (x_grouped - mean) / torch.sqrt(var + eps)x_norm = x_norm.view(N, C, H, W)# Scale and shiftreturn gamma.view(1, C, 1, 1) * x_norm + beta.view(1, C, 1, 1)# 示例
x = torch.randn(2, 6, 3, 3) # N=2, C=6, H=W=3
G = 2
gamma = torch.ones(6) # 可训练参数 γ
beta = torch.zeros(6) # 可训练参数 β
y = group_norm(x, G, gamma, beta)
与 BN/LN/IN 的关系
方法 | 计算统计量的范围 | 适用场景 |
---|---|---|
BatchNorm (BN) | 整个 batch 的每个通道 (N,H,W) | 大 batch 训练 |
LayerNorm (LN) | 每个样本的所有通道 (C,H,W) | RNN/Transformer |
InstanceNorm (IN) | 每个样本的每个通道 (H,W) | 风格迁移/生成模型 |
GroupNorm (GN) | 每个样本的每组通道 (C//G, H, W) | 小 batch 训练 |
总结
nn.GroupNorm
的计算步骤:
- 分组:将
C
个通道分成G
组。 - 计算统计量:对每个样本的每个组计算
μ
和σ²
。 - 归一化:用
(x - μ) / sqrt(σ² + ϵ)
标准化。 - 仿射变换:应用
γ
和β
调整输出。
它的优势是 不依赖 batch size,适合小批量或单样本任务(如检测、分割)。