Comparison of Layer Norm and RMS Norm
Layer Norm:
# Layer Norm formula
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True)
output = (x - mean) / sqrt(var + eps) * gamma + beta
Characteristics:
- Subtracts the mean (re-centering)
- Divides by the standard deviation (rescaling)
- Has two learnable parameters, gamma and beta
- Relatively higher computational cost (a quick numerical check follows below)
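As a sanity check, a minimal sketch comparing the manual formula above against PyTorch's built-in nn.LayerNorm (the input shape (2, 4, 8) is just an arbitrary example):

import torch
import torch.nn as nn

x = torch.randn(2, 4, 8)            # (batch, seq, hidden), arbitrary example shape
ln = nn.LayerNorm(8, eps=1e-6)      # weight (gamma) starts at 1, bias (beta) at 0

# Manual computation following the formula above
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-6) * ln.weight + ln.bias

print(torch.allclose(manual, ln(x), atol=1e-5))  # expected: True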
RMS Norm (Root Mean Square Normalization):
# RMS Norm formula
rms = sqrt(mean(x**2) + eps)
output = x / rms * gamma
Characteristics:
- Does not subtract the mean (no re-centering)
- Only divides by the RMS value
- Only one learnable parameter, gamma
- Simpler and more efficient to compute (see the sketch below for the relation to Layer Norm)
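To make the "no mean subtraction" difference concrete, a small sketch: when the input already has zero mean along the last dimension, RMS Norm and Layer Norm give (up to eps) the same output, so the only extra work Layer Norm does is the re-centering step:

import torch

eps = 1e-6
x = torch.randn(2, 8)
x_centered = x - x.mean(dim=-1, keepdim=True)   # force zero mean per row

# RMS Norm (without gamma) on the centered input
rms_out = x_centered / torch.sqrt((x_centered**2).mean(dim=-1, keepdim=True) + eps)
# Layer Norm (without gamma/beta) on the same input
ln_out = (x_centered - x_centered.mean(dim=-1, keepdim=True)) / torch.sqrt(
    x_centered.var(dim=-1, keepdim=True, unbiased=False) + eps
)

print(torch.allclose(rms_out, ln_out, atol=1e-5))  # expected: True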
Comparison
Code comparison
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        # Subtract the mean and divide by the std over the last dimension, then apply gamma/beta
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Divide by the root mean square over the last dimension; no mean subtraction, no beta
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma
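A short usage sketch for the two classes above (hidden size 512 is an arbitrary example), showing identical output shapes and the parameter-count difference:

x = torch.randn(2, 4, 512)
layer_norm = LayerNorm(512)
rms_norm = RMSNorm(512)

print(layer_norm(x).shape, rms_norm(x).shape)            # both: torch.Size([2, 4, 512])
print(sum(p.numel() for p in layer_norm.parameters()))   # 1024 (gamma + beta)
print(sum(p.numel() for p in rms_norm.parameters()))     # 512  (gamma only)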