当前位置: 首页 > news >正文

深度学习中的归一化技术详解:BN、LN、IN、GN

1. Batch Normalization (BN, 2015)

核心思想

  • 对Batch维度的每个特征通道进行归一化
  • 训练时用当前batch统计量,测试时用全局移动平均

计算步骤

# 输入x形状: [B, C, H, W] (CNN) 或 [B, D] (全连接)
mean = x.mean(dim=[0, 2, 3])  # 沿Batch/空间维度求均值
var = x.var(dim=[0, 2, 3], unbiased=False)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta  # 可学习参数γ,β

μₖ = (1/(B×H×W)) × Σ(xᵢⱼₖₗ) (i=1…B, j=1…H, l=1…W)
σₖ² = (1/(B×H×W)) × Σ(xᵢⱼₖₗ - μₖ)²
x̂ᵢⱼₖₗ = (xᵢⱼₖₗ - μₖ) / √(σₖ² + ε)
yᵢⱼₖₗ = γₖ × x̂ᵢⱼₖₗ + βₖ

特点

优点缺点
✅ 加速收敛
✅ 允许更大学习率
✅ 有正则化效果
❌ 依赖大batch(通常>16)
❌ 不适用于RNN/Dynamic NN

PyTorch实现

nn.BatchNorm2d(num_features)  # CNN
nn.BatchNorm1d(num_features)  # FC/RNN

2. Layer Normalization (LN, 2016)

核心思想

  • 对每个样本的所有特征进行归一化
  • 常用于Transformer和RNN

计算步骤

# 输入x形状: [B, T, D] (Transformer)
mean = x.mean(dim=-1, keepdim=True)  # 特征维度
var = x.var(dim=-1, keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta

μᵢ = (1/D) × Σ(xᵢⱼ) (j=1…D)
σᵢ² = (1/D) × Σ(xᵢⱼ - μᵢ)²
x̂ᵢⱼ = (xᵢⱼ - μᵢ) / √(σᵢ² + ε)
yᵢⱼ = γⱼ × x̂ᵢⱼ + βⱼ

特点

优点缺点
✅ 不依赖batch size
✅ 适合动态网络
❌ CNN效果不如BN

PyTorch实现

nn.LayerNorm(normalized_shape)  # normalized_shape=D

3. Instance Normalization (IN, 2017)

核心思想

  • 对每个样本的每个通道单独归一化
  • 风格迁移任务常用

计算步骤

# 输入x形状: [B, C, H, W]
mean = x.mean(dim=[2, 3], keepdim=True)  # 空间维度
var = x.var(dim=[2, 3], keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = gamma * x_hat + beta  # 可选

μᵢₖ = (1/(H×W)) × Σ(xᵢₖⱼₗ) (j=1…H, l=1…W)
σᵢₖ² = (1/(H×W)) × Σ(xᵢₖⱼₗ - μᵢₖ)²
x̂ᵢₖⱼₗ = (xᵢₖⱼₗ - μᵢₖ) / √(σᵢₖ² + ε)
yᵢₖⱼₗ = γₖ × x̂ᵢₖⱼₗ + βₖ (可选)

特点

优点缺点
✅ 保留样本间独立性
✅ 适合风格迁移
❌ 破坏通道间相关性

PyTorch实现

nn.InstanceNorm2d(num_features)

4. Group Normalization (GN, 2018)

核心思想

  • 将通道分组后对每组进行归一化
  • CNN小batch场景的BN替代方案

计算步骤

# 输入x形状: [B, C, H, W], 设groups=G
x = x.view(B, G, C//G, H, W)  # 分组
mean = x.mean(dim=[2, 3, 4], keepdim=True)
var = x.var(dim=[2, 3, 4], keepdim=True)
x_hat = (x - mean) / sqrt(var + eps)
out = x_hat.view(B, C, H, W) * gamma + beta
分组后形状: [B, G, C//G, H, W]

μᵢ₉ = (1/((C//G)×H×W)) × Σ(xᵢ₉ₖⱼₗ)
σᵢ₉² = (1/((C//G)×H×W)) × Σ(xᵢ₉ₖⱼₗ - μᵢ₉)²
x̂ᵢ₉ₖⱼₗ = (xᵢ₉ₖⱼₗ - μᵢ₉) / √(σᵢ₉² + ε)

恢复形状后:

yᵢₖⱼₗ = γₖ × x̂ᵢₖⱼₗ + βₖ

特点

优点缺点
✅ 小batch表现好
✅ 精度接近BN
❌ 计算量稍大

PyTorch实现

nn.GroupNorm(num_groups, num_channels)

5.对比总结

方法归一化维度适用场景Batch依赖
BN[B, H, W]大batch/CNN
LN[D]RNN/Transformer
IN[H, W]风格迁移/生成模型
GN[G, H, W]小batch CNN

代码示例(四种归一化对比)

import torch.nn as nn# 输入假设: [2, 6, 224, 224] (batch=2, channels=6)
bn = nn.BatchNorm2d(6)
ln = nn.LayerNorm([6, 224, 224])  # 全特征归一化
in = nn.InstanceNorm2d(6)
gn = nn.GroupNorm(num_groups=3, num_channels=6)  # 分2组

如何选择?

  1. CNN:优先尝试BN → batch<8时用GN
  2. RNN/Transformer:必选LN
  3. Style Transfer:首选IN
  4. 小batch CNN:GN+LN组合

📌 经验法则:当BN效果不佳时,根据任务特性尝试其他归一化方法


6. Transformer架构中的归一化标准方案

现代大语言模型普遍采用 Pre-LayerNorm 结构,即在注意力/FFN层之前进行归一化:

输入 → LayerNorm → Attention → 残差连接 → LayerNorm → FFN → 残差连接

6.1 ChatGPT (OpenAI GPT系列)

模型版本归一化方案关键细节
GPT-2LayerNorm经典Post-LN
GPT-3LayerNorm改为Pre-LN
GPT-4LayerNorm + 改进可能引入RMSNorm

特点

  • 始终坚持LayerNorm
  • 从Post-LN转向更稳定的Pre-LN结构

6.2 DeepSeek

模型版本归一化方案关键细节
DeepSeek-MoELayerNormPre-LN结构
DeepSeek-CoderLayerNorm代码模型同样架构

创新点

  • 在MoE架构中保持LayerNorm一致性
  • 对长上下文优化了Norm位置

6.3 Qwen (阿里通义千问)

模型版本归一化方案关键细节
Qwen-1.8BLayerNorm标准实现
Qwen-72BRMSNorm性能优化

技术演进

  • 大参数模型改用RMSNorm减少计算量
  • 保留LayerNorm的缩放偏移参数

6.4为什么不用BatchNorm?

所有主流LLM都避免使用BN,原因包括:

  1. 序列长度可变:BN需要固定维度,但文本长度动态变化
  2. 小batch推理:预测时batch_size=1,BN统计量失效
  3. 训练不稳定:文本数据的稀疏性导致BN方差估计不准

6.5 进阶变体:RMSNorm

新兴模型(如LLaMA、Qwen-72B)开始采用 RMSNorm(Root Mean Square Normalization):

def rms_norm(x, eps=1e-6):# 去均值操作(相比LayerNorm)return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps

RMS(x) = √((1/D) × Σ(xⱼ²) + ε)
yᵢ = (xᵢ / RMS(x)) × γᵢ

优势

  • 计算量减少约20%(适合超大模型)
  • 在Transformer中表现接近LayerNorm

6.6 模型实现对比表

模型归一化方案结构位置是否含β/γ
GPT-4LayerNormPre-LN
LLaMA-2RMSNormPre-LN
Qwen-72BRMSNormPre-LN
DeepSeek-MoELayerNormPre-LN

6.7关键结论

  1. LayerNorm仍是主流:90%以上的LLM使用
  2. Pre-LN成为标准:比原始Transformer的Post-LN更稳定
  3. RMSNorm是趋势:新模型为效率逐步转向RMSNorm
  4. 绝对不用BN:所有文本模型都避免BatchNorm
http://www.dtcms.com/a/274996.html

相关文章:

  • Kubernetes 高级调度特性
  • C语言:位运算
  • Redis 哨兵机制
  • 多代理系统(multi-agent)框架深度解析:架构、特性与未来
  • 无代码自动化测试工具
  • STM32G473串口通信-USART/UART配置和清除串口寄存器状态的注意事项
  • 隆重介绍 Xget for Chrome:您的终极下载加速器
  • 开源界迎来重磅核弹!月之暗面开源了自家最新模型 K2
  • 从延迟测试误区谈起:SmartPlayer为何更注重真实可控的低延迟?
  • gitee 代码仓库面试实际操作题
  • Cadence Virtuoso中如何集成Calibre
  • Java进阶---并发编程
  • 打造未来制造核心力:虚拟调试的价值与落地思路
  • YOLO-DETR如何提升小目标的检测效果
  • 【数据结构与算法】数据结构初阶:详解顺序表和链表(三)——单链表(上)
  • OpenCV实现感知哈希(Perceptual Hash)算法的类cv::img_hash::PHash
  • 商城网站建设实务
  • Ragflow-plus本地部署和智能问答及报告编写应用测试
  • 标准化模型格式ONNX介绍:打通AI模型从训练到部署的环节
  • C语言易错点(二)
  • C++包管理工具:conan2常用命令详解
  • JVM-----【并发可达性分析】
  • Android 12系统源码_分屏模式(一)从最近任务触发分屏模式
  • 微信小程序核心知识点速览
  • OpenCV图像基本操作:读取、显示与保存
  • OpenLLMetry 助力 LLM 应用实现可观测性
  • 1-Git安装配置与远程仓库使用
  • uniapp---入门、基本配置了解
  • springboot-2.3.3.RELEASE升级2.7.16,swagger2.9.2升级3.0.0过程
  • 猿人学js逆向比赛第一届第十九题