Qwen3_moe模型代码解析
Qwen3_moe模型代码解析
1) 顶层:Qwen3MoeModel.forward
(Embedding → 多层解码器 → RMSNorm)
要点
- 位置编码:MoE 是 1D RoPE(
position_ids
形状(1,B,S)
)。 - 掩码:根据配置可能是标准因果或滑窗因果。
- Cache:默认用
DynamicCache
(用于增量解码)。
2) 解码器层:Qwen3MoeDecoderLayer.forward
(Self-Attn → 残差 → MoE/MLP → 残差)
要点
- MoE 层与普通 MLP 层互斥;MoE 层额外产生
router_logits
(被上层输出记录用于负载均衡辅助损失)。
3) 注意力:Qwen3MoeAttention.forward
(Q/K/V → RoPE → 注意力 → 合并头)
要点
- 与 Qwen2 系列类似,GQA:
n_heads = num_attention_heads
,n_kv = num_key_value_heads
,通过repeat_kv
对齐。 - 层内对 Q/K 施加 RMSNorm(按头维) 再做 RoPE(这是 Qwen3 MoE 和一些实现的一个小差别)。
4) 稀疏 MoE:Qwen3MoeSparseMoeBlock.forward
(Gating → Top-k 路由 → 专家并行 → 汇聚)
要点
- 该实现是token-level routing;每 token 选 top-k 专家;支持
norm_topk_prob
对 top-k 权重归一化。 - 通过
index_add_
将各专家输出按原 token 位置汇聚。
5) RoPE:Qwen3MoeRotaryEmbedding.forward
(1D 位置 → cos/sin)
要点
- 与标准 1D RoPE 一致(没有多模态 3D 拆段)。
dynamic_rope_update
允许动态扩展(取决于rope_scaling
策略)。
6) 语言建模头:Qwen3MoeForCausalLM.forward
(LM Head & Router Loss)
要点
logits_to_keep
:只在末 K 个时间步计算lm_head
,显存友好。aux_loss
:负载均衡损失,鼓励专家使用更均匀。
7) 掩码构造与缓存(顶层)
8) 形状清单(常用变量)
-
B
: batch size;S
: 当前序列长度;H
: hidden_size;V
: vocab_size -
头部:
n_heads = num_attention_heads
,n_kv = num_key_value_heads
,d = head_dim = H / n_heads
-
注意力中:
- Q:
(B,n_heads,S,d)
;K/V:(B,n_kv,S,d)
→repeat_kv
→(B,n_heads,S,d)
- 权重
(B,n_heads,S,S)
;输出(B,S,H)
- Q:
9) 常见坑与对策
- inputs 选择:
(input_ids is None) XOR (inputs_embeds is not None)
必须成立,否则抛错。 - RoPE 维度:
position_ids
必须(1,B,S)
;若用 cache,需要正确设置cache_position
使位置连续。 - 滑窗注意力:窗口
W
太小会影响长程依赖;太大则近似全因果。确保与训练/推理对齐。 - MoE 路由:
num_experts_per_tok (top_k)
影响吞吐与均衡;norm_topk_prob=True
时要注意与训练策略匹配。 - 负载均衡损失:
output_router_logits=True
时才会收集所有层的router_logits
;注意与attention_mask
一起计算避免 padding 干扰。 - 精度:注意力
softmax
强制float32
再 cast 回来,避免数值不稳。
10) 端到端数字化算例(便于核对)
假设:
B=2, S=128, H=4096, V=151936
;n_heads=32 → d=128
;n_kv=8 → num_key_value_groups=4
;sliding_window=None
(全因果);use_cache=True
首次前向past=None
;- 第 4 层是 MoE 层:
num_experts=8, top_k=2, moe_intermediate_size=11008
,其它层为密集 MLP:intermediate_size=11008
。
流程
-
inputs_embeds = embed_tokens(input_ids)
→(2,128,4096)
-
cache_position = [0..127]
,position_ids=(1,2,128)
-
causal_mask=(2,1,128,128)
上三角 -inf -
进入第 1 层:
- Q/K/V 线性:
(2,128,4096) → Q:(2,128,4096) K/V:(2,128,1024)
- 视图→
Q:(2,32,128,128)
;K/V:(2,8,128,128)
→repeat_kv
→(2,32,128,128)
- RoPE:
apply_rotary_pos_emb
- 注意力:权重
(2,32,128,128)
→ 输出(2,32,128,128)
→ 合并头(2,128,4096)
- 残差 + MLP (SwiGLU):
(2,128,4096)
→(2,128,11008)
→(2,128,4096)
- Q/K/V 线性:
-
第 4 层(MoE):
- Gate:
(B*S,H)=(256,4096) → (256,8)
→ softmax → top2 - 对被命中专家的 tokens 送入各自
MLP_e
:(N_e,4096)→(N_e,11008)→(N_e,4096)
,乘以各 token 对应权重 index_add_
汇聚回(256,4096)
→ reshape(2,128,4096)
- 残差
router_logits
记录(供 loss)
- Gate:
-
L 层结束 →
RMSNorm
→last_hidden_state=(2,128,4096)
-
lm_head:(4096→V)
只算末K=32
步 →logits=(2,32,V)
-
如有
labels
:交叉熵 + 若output_router_logits=True
再加aux_loss
(乘以router_aux_loss_coef
)。