当前位置: 首页 > news >正文

深入解析PyTorch中MultiheadAttention的参数key_padding_mask与attn_mask

1. 基本背景

在multiheadattention中存在两个mask,一个参数是key_padding_mask,另外一个是attn_mask,尽管这两个参数是被人们所熟知的填充掩码和注意力掩码,但是深度理解以便清晰区分对于深刻理解该架构非常重要。

2. 参数Key_padding_mask(关键填充掩码)

  • 用途:防止模型关注到输入序列中用 <pad> 填充的位置。
  • 场景:对变长输入进行 padding 后,避免注意力将注意力权重分配到 padding token 上。
  • 应用位置:在计算注意力时,对 所有 query 的 key 位置 进行屏蔽。

✅维度

# key_padding_mask shape: (batch_size, seq_len)

✅ 示例

key_padding_mask = torch.tensor([[False, False, True], [False, True, True]])
# 表示第一个样本第3个位置是pad,第二个样本第2,3个位置是pad

3. 参数Attn_mask(注意力掩码)

  • 用途:对注意力矩阵中任意 query-key 对的连接进行屏蔽,更灵活。
  • 场景:
    • Transformer 解码器中的 自回归遮蔽(causal mask)
    • 限定注意力只能在局部范围内滑动(局部注意力)
    • 自定义 mask,如节省计算或实验结构

✅ 维度

# [tgt_len, src_len](用于所有 batch 和 head)
# 或 [batch_size * num_heads, tgt_len, src_len](用于每个 head 的个性化 mask)

✅ 示例:causal mask

# 上三角为 True,代表“未来的信息被屏蔽”,用于解码器自回归。
tgt_len = 5
attn_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()

4. 工作流程中的区别⚠️⚠️⚠️

在计算 Q ∗ K T Q*K^T QKT之后:

  1. 先应用attn_mask(对齐注意力矩阵维度,屏蔽某些query-key配对);
  2. 再应用key_padding_mask(对每个样本的padding key屏蔽);
  3. 最后经过softmax处理

5. 类比理解

  • key_padding_mask 像是说:“这些 token 是 padding,不用关注它们。”
  • attn_mask 像是说:“这些 query-key 配对不允许有连接(比如未来的信息)。”

相关文章:

  • 分布式与集群:概念、区别与协同
  • disryptor和rabbitmq
  • RabbitMQ-如何选择消息队列?
  • 大语言模型(LLM)如何通过“思考时间”(即推理时的计算资源)提升推理能力
  • Java设计模式之组合模式:从入门到精通(保姆级教程)
  • 【NLP】37. NLP中的众包
  • Better Faster Large Language Models via Multi-token Prediction 原理
  • 【NLP】34. 数据专题:如何打造高质量训练数据集
  • femap许可与多用户共享
  • (二十二)Java File类与IO流全面解析
  • 怎么样进行定量分析
  • 在 Java MyBatis 中遇到 “操作数类型冲突: varbinary 与 float 不兼容” 的解决方法
  • python打卡day30@浙大疏锦行
  • 团队氛围紧张,如何提升工作积极性?
  • RSA(公钥加密算法)
  • token令牌
  • Image and depth from a conventional camera with a coded aperture论文阅读
  • day30python打卡
  • FPGA:高速接口JESD204B以及FPGA实现
  • 动态IP技术在跨境电商中的创新应用与战略价值解析
  • 王毅同德国外长瓦德富尔通电话
  • 网络直播间销售玩具盲盒被指侵权,法院以侵犯著作权罪追责
  • 中国旅马大熊猫“福娃”和“凤仪”启程回国
  • 马上评|科学红毯,让科学家成为“最亮的星”
  • 中国新闻发言人论坛在京举行,郭嘉昆:让中国声音抢占第一落点
  • 南宁一学校发生伤害案件,警方通报:嫌疑人死亡,2人受伤