当前位置: 首页 > news >正文

【nn.GroupNorm】

组归一化(Group Normalization, GN)是一种深度学习中的归一化方法,旨在解决批量归一化(Batch Normalization, BN)在小批量(mini-batch)较小时性能下降的问题。GN 通过将通道(channels)分组进行归一化,不依赖于批量维度,因此对批量大小不敏感。

GN 的数学公式

给定输入张量 x x x,其维度为 ( N , C , H , W ) (N, C, H, W) (N,C,H,W)(分别表示批量大小、通道数、高度、宽度),GN 的计算步骤如下:

  1. 分组
    C C C 个通道分成 G G G 组,每组包含 C G \frac{C}{G} GC个通道(假设 ( C ) 能被 ( G ) 整除)。
    x n , g x_{n,g} xn,g表示第 n n n 个样本的第 g g g 组特征(维度为 C G × H × W \frac{C}{G} \times H \times W GC×H×W)。

  2. 计算均值和方差
    对每个样本的每个组分别计算均值和方差:
    μ n , g = 1 ∣ G ∣ ∑ c ∈ G ∑ h , w x n , c , h , w , \mu_{n,g} = \frac{1}{|\mathcal{G}|} \sum_{c \in \mathcal{G}} \sum_{h,w} x_{n,c,h,w}, μn,g=G1cGh,wxn,c,h,w,
    σ n , g 2 = 1 ∣ G ∣ ∑ c ∈ G ∑ h , w ( x n , c , h , w − μ n , g ) 2 , \sigma_{n,g}^2 = \frac{1}{|\mathcal{G}|} \sum_{c \in \mathcal{G}} \sum_{h,w} (x_{n,c,h,w} - \mu_{n,g})^2, σn,g2=G1cGh,w(xn,c,h,wμn,g)2,
    其中 G \mathcal{G} G表示当前组的所有通道索引, ∣ G ∣ = C G × H × W |\mathcal{G}| = \frac{C}{G} \times H \times W G=GC×H×W

  3. 归一化
    使用计算得到的均值和方差对输入进行归一化:
    x ^ n , c , h , w = x n , c , h , w − μ n , g σ n , g 2 + ϵ , \hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^2 + \epsilon}}, x^n,c,h,w=σn,g2+ϵ xn,c,h,wμn,g,
    其中 ϵ \epsilon ϵ 是一个很小的常数(如 10 − 5 10^{-5} 105),用于数值稳定性。

  4. 仿射变换(可选)
    类似于 BN,GN 也可以引入可学习的缩放(scale)和平移(shift)参数:
    y n , c , h , w = γ c x ^ n , c , h , w + β c , y_{n,c,h,w} = \gamma_c \hat{x}_{n,c,h,w} + \beta_c, yn,c,h,w=γcx^n,c,h,w+βc,
    其中 γ c \gamma_c γc β c \beta_c βc 是可训练的参数,分别对应缩放和平移。

特点

  • 不依赖批量大小:GN 对每个样本独立计算统计量,适用于小批量或单样本训练(如目标检测、视频处理等任务)。
  • 组数 ( G ) 是超参数
    • 当 ( G=1 ),GN 退化为 Layer Normalization (LN)(对整个通道归一化)。
    • 当 ( G=C ),GN 退化为 Instance Normalization (IN)(每个通道单独归一化)。
  • 计算开销比 BN 略高:因为需要逐样本计算统计量,但比 LN 和 IN 更灵活。

PyTorch 实现

在 PyTorch 中,GN 可通过 torch.nn.GroupNorm 实现:

import torch.nn as nn# 假设输入通道数 C=32,组数 G=8
gn = nn.GroupNorm(num_groups=8, num_channels=32)
output = gn(input_tensor)  # input_tensor: (N, 32, H, W)

适用场景

  • 小批量训练(如 batch size < 16)。
  • 任务对批量统计量敏感(如检测、分割、生成模型等)。

GN 在 ResNetMask R-CNN 等模型中表现良好,尤其在批量较小时优于 BN。

在 PyTorch 中,nn.GroupNorm 的计算过程可以分为以下几个步骤,我们结合代码和数学公式详细说明。


1. 输入张量的形状

假设输入张量 x 的维度为 (N, C, H, W),其中:

  • N:batch size(样本数量)
  • C:通道数(channels)
  • HW:特征图的高度和宽度

例如,x.shape = (2, 6, 3, 3) 表示:

  • N = 2(2 个样本)
  • C = 6(6 个通道)
  • H = W = 3(3×3 的特征图)

2. 分组(Grouping)

nn.GroupNorm 的关键是将通道 C 分成 G 组(G 是超参数)。

  • 每组通道数:C_per_group = C // G
  • 要求 C 必须能被 G 整除,否则会报错。

示例
如果 C = 6G = 2,则:

  • 每组有 6 // 2 = 3 个通道
  • 组 0:通道 [0, 1, 2]
  • 组 1:通道 [3, 4, 5]

3. 计算均值和方差(Per-Group Statistics)

每个样本 n每个组 g,计算该组所有通道的均值和方差:

  • 计算范围:(C_per_group, H, W) 的所有值
  • 公式:
    μ n , g = 1 C-per-group × H × W ∑ c ∈ group g ∑ h , w x n , c , h , w \mu_{n,g} = \frac{1}{\text{C-per-group} \times H \times W} \sum_{c \in \text{group}_g} \sum_{h,w} x_{n,c,h,w} μn,g=C-per-group×H×W1cgroupgh,wxn,c,h,w
    σ n , g 2 = 1 C-per-group × H × W ∑ c ∈ group g ∑ h , w ( x n , c , h , w − μ n , g ) 2 \sigma_{n,g}^2 = \frac{1}{\text{C-per-group} \times H \times W} \sum_{c \in \text{group}_g} \sum_{h,w} (x_{n,c,h,w} - \mu_{n,g})^2 σn,g2=C-per-group×H×W1cgroupgh,w(xn,c,h,wμn,g)2

示例x.shape = (2, 6, 3, 3)G=2):

  • n=0, g=0(第 0 个样本的第 0 组,通道 [0,1,2]),计算这 3×3×3=27 个值的均值和方差。
  • n=0, g=1(第 0 个样本的第 1 组,通道 [3,4,5]),同样计算均值和方差。
  • n=1 的所有组重复上述过程。

4. 归一化(Normalize)

使用计算出的 μσ² 对每个组进行归一化:
x ^ n , c , h , w = x n , c , h , w − μ n , g σ n , g 2 + ϵ \hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^2 + \epsilon}} x^n,c,h,w=σn,g2+ϵ xn,c,h,wμn,g
其中:

  • ϵ 是一个极小值(默认 1e-5),防止除以零。
  • g 是通道 c 所属的组(g = c // (C // G))。

示例

  • 如果 c=4G=2C=6,则 g = 4 // 3 = 1(属于第 1 组)。
  • μ_{n,1}σ_{n,1}^2 归一化通道 4 的所有值。

5. 仿射变换(Scale & Shift)

GN 最后会应用可学习的缩放(γ)和平移(β)参数:
y n , c , h , w = γ c ⋅ x ^ n , c , h , w + β c y_{n,c,h,w} = \gamma_c \cdot \hat{x}_{n,c,h,w} + \beta_c yn,c,h,w=γcx^n,c,h,w+βc
其中:

  • γβ 是可训练参数,形状为 (C,)(每个通道独立)。
  • 如果初始化时 affine=False,则跳过此步(γ=1β=0)。

PyTorch 代码模拟

以下是手动实现 nn.GroupNorm 的代码(简化版):

import torchdef group_norm(x, G, gamma, beta, eps=1e-5):N, C, H, W = x.shapeassert C % G == 0, "C must be divisible by G"C_per_group = C // G# Reshape to (N, G, C_per_group, H, W)x_grouped = x.view(N, G, C_per_group, H, W)# Compute mean and var per group (shape: N, G)mean = x_grouped.mean(dim=(2, 3, 4), keepdim=True)var = x_grouped.var(dim=(2, 3, 4), keepdim=True, unbiased=False)# Normalizex_norm = (x_grouped - mean) / torch.sqrt(var + eps)x_norm = x_norm.view(N, C, H, W)# Scale and shiftreturn gamma.view(1, C, 1, 1) * x_norm + beta.view(1, C, 1, 1)# 示例
x = torch.randn(2, 6, 3, 3)  # N=2, C=6, H=W=3
G = 2
gamma = torch.ones(6)  # 可训练参数 γ
beta = torch.zeros(6)  # 可训练参数 β
y = group_norm(x, G, gamma, beta)

与 BN/LN/IN 的关系

方法计算统计量的范围适用场景
BatchNorm (BN)整个 batch 的每个通道 (N,H,W)大 batch 训练
LayerNorm (LN)每个样本的所有通道 (C,H,W)RNN/Transformer
InstanceNorm (IN)每个样本的每个通道 (H,W)风格迁移/生成模型
GroupNorm (GN)每个样本的每组通道 (C//G, H, W)小 batch 训练

总结

nn.GroupNorm 的计算步骤:

  1. 分组:将 C 个通道分成 G 组。
  2. 计算统计量:对每个样本的每个组计算 μσ²
  3. 归一化:用 (x - μ) / sqrt(σ² + ϵ) 标准化。
  4. 仿射变换:应用 γβ 调整输出。

它的优势是 不依赖 batch size,适合小批量或单样本任务(如检测、分割)。

相关文章:

  • MQTT协议,EMQX部署,MQTTX安装学习
  • 苹果签名工具
  • 每天掌握一个Linux命令 - curl
  • 代码随想录算法训练营第60期第五十二天打卡
  • SpringBoot+Vue+微信小程序校园自助打印系统
  • [SWPUCTF 2023 秋季新生赛]Classical Cipher203分古典密码Base家族栅栏密码
  • 【xmb】内部文档148344596
  • RAG中的chunk以及评测方法
  • 辅助脚本-通用开发工作区目录结构生成脚本解析与实践指南
  • 5G 核心网 NGAP UE-TNL 偶联和绑定
  • C++学习-入门到精通【10】面向对象编程:多态性
  • 论坛系统(4)
  • C++核心编程_赋值运算符重载
  • 多线程(3)
  • 带sdf 的post sim 小结
  • azure web app创建分步指南系列之一
  • CMP401GSZ-REEL混合电压接口中的23ns延迟与±6V输入范围设计实现
  • const ‘不可变’到底是值不变还是地址不变
  • 痉挛性斜颈相关内容说明
  • 无人机桥梁3D建模、巡检、检测的航线规划
  • 企业做网站的注意什么/怎么自己做一个网址
  • 男朋友说是做竞彩网站维护的/域名查询注册商
  • 网站建设板块免费下载/企业如何开展网络营销
  • 2003网站服务器建设中/北大青鸟
  • 网站建设费用是多少钱/河北软文搜索引擎推广公司
  • 国外注册的域名国内做的网站/网站建设公司网站