深入解析PyTorch中MultiheadAttention的参数key_padding_mask与attn_mask
1. 基本背景
在multiheadattention中存在两个mask,一个参数是key_padding_mask,另外一个是attn_mask,尽管这两个参数是被人们所熟知的填充掩码和注意力掩码,但是深度理解以便清晰区分对于深刻理解该架构非常重要。
2. 参数Key_padding_mask(关键填充掩码)
- 用途:防止模型关注到输入序列中用 <pad> 填充的位置。
- 场景:对变长输入进行 padding 后,避免注意力将注意力权重分配到 padding token 上。
- 应用位置:在计算注意力时,对 所有 query 的 key 位置 进行屏蔽。
✅维度
# key_padding_mask shape: (batch_size, seq_len)
✅ 示例
key_padding_mask = torch.tensor([[False, False, True], [False, True, True]])
# 表示第一个样本第3个位置是pad,第二个样本第2,3个位置是pad
3. 参数Attn_mask(注意力掩码)
- 用途:对注意力矩阵中任意 query-key 对的连接进行屏蔽,更灵活。
- 场景:
- Transformer 解码器中的 自回归遮蔽(causal mask)
- 限定注意力只能在局部范围内滑动(局部注意力)
- 自定义 mask,如节省计算或实验结构
✅ 维度
# [tgt_len, src_len](用于所有 batch 和 head)
# 或 [batch_size * num_heads, tgt_len, src_len](用于每个 head 的个性化 mask)
✅ 示例:causal mask
# 上三角为 True,代表“未来的信息被屏蔽”,用于解码器自回归。
tgt_len = 5
attn_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
4. 工作流程中的区别⚠️⚠️⚠️
在计算 Q ∗ K T Q*K^T Q∗KT之后:
- 先应用
attn_mask
(对齐注意力矩阵维度,屏蔽某些query-key配对); - 再应用
key_padding_mask
(对每个样本的padding key屏蔽); - 最后经过
softmax
处理
5. 类比理解
- key_padding_mask 像是说:“这些 token 是 padding,不用关注它们。”
- attn_mask 像是说:“这些 query-key 配对不允许有连接(比如未来的信息)。”