Transformer 训不动:注意力 Mask 用反 / 广播错位
Transformer 训不动:注意力 Mask 用反 / 广播错位,PAD 被看到导致 Loss 不降(一次从 nn.MultiheadAttention 到 SDPA 的排障实录)
自选日志 · 深度学习代码
场景:做中文小型 GPT,明明 Batch/学习率都合理,但 loss 长期不降或震荡,ppl 比基线还差。单句推理“看起来能写”,但验证集一直拉胯。最后定位是 注意力 Mask 写错:把 PAD/未来位当可见、或把 0/1 语义用反,甚至让 Mask 在 [B,T] 和 [B,h,T,T] 间 广播错位。
本文记录完整复盘:现象 → MRE → 排查 → 修复,并给出 PyTorch 2.x 推荐写法(F.scaled_dot_product_attention,布尔 Mask 更稳)。
技术环境
Python 3.10 / PyTorch 2.2.x(CUDA 12.x)
任务:Causal LM(解码器-only Transformer)
组件:nn.MultiheadAttention(早期) / F.scaled_dot_product_attention(SDPA,修复后)
Bug 现象
训练 loss 基本不降(或震荡 3.8~4.2),grad_norm 正常、无 NaN。
Batch=1 还能降一点,一到 Batch≥8 就“卡住”。
评估 ppl 异常、EOS 附近预测错乱;可视化注意力发现 PAD/未来位置有高权重。
换优化器/学习率没用;关掉 Dropout/LayerNorm 也无改善。
最小可复现(错误版)
错误 1:把 0/1 Mask 直接加到注意力分数上
wrong_attn_add_mask.py —— 不要照抄
import torch, torch.nn.functional as F
B, h, T, d = 4, 8, 16, 64
q = torch.randn(B, h, T, d)
k = torch.randn(B, h, T, d)
v = torch.randn(B, h, T, d)
pad_id = 0
tokens = torch.randint(1, 100, (B, T))
tokens[:, -3:] = pad_id # 尾部 PAD
❌ 错:用 1 表示可见,0 表示屏蔽,然后直接“加”到分数
(很多人从 HF attention_mask 误迁移)
vis_mask = (tokens != pad_id).float() # [B, T],1=可见,0=PAD
scores = (q @ k.transpose(-1, -2)) / d**0.5 # [B,h,T,T]
scores = scores + vis_mask[:, None, None, :] # ⛔ 广播后只是+0或+1,PAD仍可见
attn = torch.softmax(scores, dim=-1) # PAD 被分到概率
out = attn @ v
错误 2:key_padding_mask 语义用反(nn.MultiheadAttention)
import torch, torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=dh, num_heads=h, batch_first=True)
x = torch.randn(B, T, dh)
❌ 传入“1=可见,0=PAD”的 mask,但 MHA 的 key_padding_mask 语义是 True=PAD
attention_mask = (tokens != pad_id) # True=可见
out, _ = mha(x, x, x, key_padding_mask=attention_mask) # ⛔ 语义反了
两类错误共同后果:模型能看见 PAD/未来位,梯度里掺进“无意义对齐”,尤其 Batch 一大,PAD 数量成规模,学习被噪声淹没。
排查步骤(真实过程)
1)立即打印 Mask 统计与形状
print(“vis_mask sum per batch:”, (tokens != pad_id).sum(dim=1)) # 看看 PAD 比例
print(“scores:”, scores.shape) # 期望 [B,h,T,T]
print(“mask:”, vis_mask[:,None,None,:].shape) # [B,1,1,T]
若 loss 对 Batch 极度敏感,先怀疑 PAD 处理。
2)在一条样本上可视化一行注意力
i, t = 0, T-1 # 看最后一个非 PAD 位置的注意力
print(“attn to PAD tail:”, attn[i, 0, t, -4:])
期望 ~0;若出现明显权重,基本就是 mask 错了
3)切到 SDPA + 布尔 Mask 验证
PyTorch 2 的 SDPA 支持 布尔 Mask(True=屏蔽),可以快速验证你的语义是否写对:
from torch.nn.functional import scaled_dot_product_attention as sdpa
pad_mask = (tokens == pad_id) # True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1) # True=未来位
attn_mask = pad_mask[:, None, None, :] | causal[None, None, :, :] # 布尔或
out = sdpa(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
若这样 loss 明显变好,你的旧实现 99% 是 mask 语义/广播错了
修复方案(两条主线)
✅ 推荐:统一使用布尔 Mask + SDPA
布尔 Mask 在 AMP/半精度下更稳(免去 -1e9 精度问题)。
import torch
import torch.nn.functional as F
def causal_pad_mask(tokens: torch.Tensor, pad_id: int):
B, T = tokens.shape
pad_mask = (tokens == pad_id) # [B,T], True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=tokens.device), 1)
# [B,1,1,T] | [1,1,T,T] → [B,1,T,T], True=屏蔽
attn_mask = pad_mask[:, None, None, :] | causal[None, None, :, :]
return attn_mask
def attention(q, k, v, tokens, pad_id=0, dropout_p=0.0):
mask = causal_pad_mask(tokens, pad_id) # 布尔
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
✅ 使用 nn.MultiheadAttention 时的正确语义
key_padding_mask: True=PAD/忽略(与 HF 的 attention_mask 语义相反)
attn_mask: [L,S] 或 [B⋯,L,S],布尔 True=屏蔽 或 加性 -inf。
mha = nn.MultiheadAttention(embed_dim=d*h, num_heads=h, batch_first=True)
key_padding_mask = (tokens == pad_id) # True=PAD
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), 1)
out, _ = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=causal)
不要把 0/1 mask 直接“加到分数上”;若用加性 mask,请用 -inf 或 -torch.finfo(scores.dtype).max,且注意 半精度溢出。
额外修补:HF 栈 vs. 原生 PyTorch 的语义差异
HF attention_mask:1=keep(可见)、0=pad(屏蔽),模型内部会转成需要的语义。
原生 PyTorch:布尔 mask 通常 True=屏蔽。
迁移时最容易“翻译错”:
key_padding_mask = (attention_mask == 0) ✅
不要把 HF 的 0/1 **“直接加”**到 scores 上。
验证
修复后,loss 在前几千步明显下降,ppl 接近基线;
可视化注意力:PAD 与未来位的权重 ≈ 0(数值 < 1e-4);
Batch 从 1 → 64,收敛曲线单调更稳,不再出现“大 batch 更差”的反常。
自检脚本(建议 PR 必跑)
def assert_attn_mask(mask: torch.Tensor, q: torch.Tensor, k: torch.Tensor):
# 1) 形状可广播到 [B,h,Tq,Tk]?
B = q.shape[0]
Tq, Tk = q.shape[-2], k.shape[-2]
try:
torch.broadcast_shapes(mask.shape, (B, 1, Tq, Tk))
except RuntimeError as e:
raise AssertionError(f"attn_mask shape {mask.shape} cannot broadcast to [B,1,Tq,Tk]") from e
# 2) 语义:True=屏蔽(布尔),或加性掩码≤0
if mask.dtype == torch.bool:
assert mask.any() or (~mask).any(), “mask seems all True/False, check pipeline”
else:
assert mask.max() <= 0, “additive mask must be <=0; use 0 for keep and -inf for block”
def sanity_check_pad_attention(attn: torch.Tensor, tokens: torch.Tensor, pad_id: int):
# attn: [B,h,T,T] softmax 后
pad_cols = (tokens == pad_id) # [B,T]
w_to_pad = attn[…, pad_cols].mean().item()
assert w_to_pad < 1e-4, f"attention to PAD too high: {w_to_pad:.2e}"
避坑清单(Checklist)
统一布尔 Mask 语义:True=屏蔽(推荐 SDPA)。
Causal:torch.triu(torch.ones(T,T,bool),1),与 PAD 取并集。
不要用 0/1 mask 直接“加”到分数上(尤其半精度)。
nn.MultiheadAttention.key_padding_mask:True=PAD;HF 的 attention_mask 则相反(1=keep)。
广播检查:mask 能否广播到 [B,1,Tq,Tk]?
大 Batch 更差 → 首查 PAD/Mask;逐样本可视化一行注意力看 PAD 权重。
Loss ignore_index ≠ 注意力 Mask:忽略标签不会阻止模型看 PAD。两者都要。
结语
很多“玄学不收敛”不是优化器的锅,而是 Mask 语义/形状 这类“工程正确性”。
把 布尔 Mask + SDPA 作为默认姿势,写上 自检脚本,你会在 5 分钟内把这类问题钉死。需要的话,我可以提供一个最小可运行的 toy-gpt 脚本(含错误版与修复版、注意力热力图对比),方便在团队里做“反例演示”。