CS336第三课
让大语言模型更快、更稳、更省:从架构到训练的工程化实践
1. 架构总览:三件“基础武器”
1.1 RMSNorm:内存与计算同等重要
• 为什么:LayerNorm 计算均值与方差,较重;RMSNorm 只用均方根(RMS),更轻、更稳,对长序列更友好。
• 工程收益:少一次均值中心化;在大 batch/长序列下显存访问更省、吞吐更稳。
• 实现提示:eps 建议 1e-6 ~ 1e-5;混合精度下用 float32 计算 RMS。
1.2 门控前馈:GLU/GeGLU/ReGLU/SwiGLU
• 思想:为 MLP 隐层引入门控,把“值流”(value)与“门”(gate)分离(GeGLU / ReGLU / SwiGLU 等)。
• 额外矩阵 V:在部分实现中新引入 V(门或值支路),让输出维度缩到原来的 ~2/3,以参数换算力,提升表示/稳定性。
• 选择建议:SwiGLU/GeGLU 通常更稳;ReGLU 更省算。
代码示例:RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
# x: [..., d]
scale = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
return x * scale * self.weight
代码示例:门控前馈(SwiGLU/GeGLU风格)
class GatedFFN(nn.Module):
def __init__(self, d_model, d_ff, act='silu'):
super().__init__()
self.w = nn.Linear(d_model, d_ff, bias=False) # value
self.v = nn.Linear(d_model, d_ff, bias=False) # gate
self.o = nn.Linear(d_ff, d_model, bias=False)
self.act = nn.SiLU() if act=='silu' else nn.GELU()
def forward(self, x):
h_val = self.w(x)
h_gate = self.act(self.v(x))
return self.o(h_val * h_gate)
1.3 Attention 与 MLP 的串行 vs 并行
• 串行(标准 Transformer):x → Attn → Res → MLP → Res。
• 并行:对同一层输入同时计算 Self-Attention 与 MLP,最后在残差流中相加融合。
• 收益:降低关键路径时延,推理更易跑满。
2. 位置编码:为什么是 RoPE?
• RoPE 把高维向量拆成二维子空间做相位旋转;相对位置直接体现在 Q·K 内积里;不依赖绝对 index,利于长上下文扩展。
• 实现要点:只对 Q/K 旋转,V 不旋转;长上下文扩展(NTK/插值/分段)要关注困惑度随长度曲线。
3. 注意力实现:形状、FlashAttention 与 KV 缓存
• 形状:q/k/v 常 reshape 为 [B, heads, seq, head_dim];确保连续内存便于 FlashAttention。
• KV cache:保留历史 K/V,每步只算新 token 的注意力,推理复杂度近似 O(T)。
• MQA/GQA:减少 KV 头数以降显存与带宽,质量—效率折中更优。
4. 训练稳定性:从数值到损失
• Softmax 数值雷区:QK-norm(对 Q·K 归一)、z-loss(惩罚极端尖锐分布)、logit soft-capping(软截断)。
• 偏置与归一化:尽量去 bias;预训练少正则/谨慎权重衰减(与优化器强耦合)。
5. 宽与深、比例与超参:性能取舍的“黄金区间”
• 宽>深通常更稳更高效(更吃到并行与 FlashAttention 回报)。
• FFN 比例:传统 4×;若用门控(SwiGLU/GeGLU),等效可缩到 ~2.66×。
• 头数/头维:head_dim × num_heads ≈ d_model;GQA/MQA 让 num_kv_heads < num_heads 以省显存。
6. 计算效率:算强比、并行与内存访问
• 提升算强比:FlashAttention(块化、寄存器复用、避免大中间张量落显存)。
• 并行:同层 Attention 与 MLP 并行;跨层流水/张量并行;KV cache 预分配与就地运算。
7. 端到端层设计:一层长啥样?(并行层示例)
class ParallelBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, num_kv_heads=None):
super().__init__()
self.norm = RMSNorm(d_model)
self.attn = MultiHeadAttention(d_model, n_heads, num_kv_heads=num_kv_heads, bias=False) # Q/K/V bias=False
self.ffn = GatedFFN(d_model, d_ff, act='silu')
def forward(self, x, cache=None, rope=None):
h = self.norm(x)
a = self.attn(h, cache=cache, rope=rope) # 内部对 Q/K 施加 RoPE
m = self.ffn(h)
return x + a + m # 并行分支在残差处相加
8. 训练与推理:一页 Checklist
训练:RMSNorm 全覆盖;去 bias;SwiGLU/GeGLU,d_ff=2.66–4×;RoPE;QK-norm/z-loss;慎用权重衰减;混合精度+grad_clip;数据去重与长短混合。
推理:KV cache;GQA/MQA;FlashAttention;并行层降 p99;量化/蒸馏;fuse/就地运算提升算强比。
9. 常见坑位与定位建议
1) 长上下文不稳:核对 RoPE 与扩窗策略;看困惑度随长度曲线。
2) 梯度爆炸/NaN:检查 QK-norm/z-loss、grad_clip、fp32 master weights;降学习率。
3) 吞吐掉点:定位 KV cache 申请与拷贝;核对 FlashAttention 的张量维序与步幅。
4) 显存超限:优先 GQA/MQA、缩小 FFN(2.66×)、门控替代普通 GELU-FFN。
5) 指令微调样本效率差:检查数据去重、分布均衡、warmup 计划。
10. 结语
这套“RMSNorm + 门控 FFN + 并行层/FlashAttention/GQA + RoPE + QK-norm/z-loss + KV cache”的组合拳,核心是把“数值稳定性”和“显存访问”当成一等公民,在同等成本下拿到更快的收敛与更稳的推理。