当前位置: 首页 > news >正文

PyTorch 的 F.scaled_dot_product_attention 返回Nan


“为什么 PyTorch 的 scaled_dot_product_attention 会输出 NaN?如何正确构造 Attention Mask”


引言:看似正常的 mask,为什么会引发 NaN?

在使用 F.scaled_dot_product_attention 构建跨模态或多源注意力时,我们常通过 attention_mask 控制每个 query 位置能看到哪些 key。但如果不小心构造出某些 query 对所有 key 都不可见的情况,就会在 softmax 中触发 NaN,进而让模型 loss 崩溃。

这个问题隐蔽却常见,且 PyTorch 不会自动容错,需要我们显式处理。


问题复现:全 -inf 行将导致 NaN

在 PyTorch 的 scaled attention 中:

output = scaled_dot_product_attention(query, key, value, attn_mask)

其中 attn_maskadditive mask,即:

  • 0.0: 表示该位置可见;
  • -inf: 表示该位置被屏蔽,不可 attend。

当某个 query 行的 mask 全为 -inf 时,softmax 输入类似于:

softmax([-inf, -inf, ..., -inf])[NaN, NaN, ..., NaN]

这将污染整个计算图,最终导致 loss 为 NaN。


产生这种情况的常见原因

这种情况经常发生在任务中存在大量 query(例如图像 patch、token、时间步)本身就不应该 attend 到任何 key,例如背景区域或 padding 区域。

因此,虽然逻辑合理,但仍然在数学上不合法


解决方案:fallback 解锁最后一个 key

为避免 NaN,可在转换 bool mask → float mask 时引入一个 fallback:

# attention_mask: [B, Q, K],bool 类型,True 表示“可以 attend”
attention_mask_float = torch.full_like(attention_mask, float('-inf'), dtype=query.dtype)
attention_mask_float.masked_fill_(attention_mask, 0.0)# fallback:避免某些 query 全为 -inf
all_inf_rows = (attention_mask_float == float('-inf')).all(dim=-1, keepdim=True)  # [B, Q, 1]
if all_inf_rows.any():last_key_idx = attention_mask_float.size(-1) - 1fix_mask = torch.arange(attention_mask_float.size(-1), device=attention_mask.device) == last_key_idxfix_mask = fix_mask.view(1, 1, -1)  # reshape for broadcastattention_mask_float = attention_mask_float.masked_fill(all_inf_rows & fix_mask, 0.0)

这样即便某个 query 原本完全不可见,也能保证 softmax 至少有一个有效分布。


可视化建议

可以使用 matplotlib.imshow 直接可视化 [Q, K] 的 mask 分布:

# 黑色:可见(0.0),白色:被 mask(-inf)
vis_mask = (attn_mask == 0.0).astype(np.uint8)
plt.imshow(vis_mask, cmap='Greys', aspect='auto')

可视化能帮助你快速定位全白 query 行,即潜在 NaN 风险点。


总结

条目建议
是否允许 query 全被屏蔽语义上允许,数学上不合法(需处理)
PyTorch 是否兜底否,需用户自己容错
是否应解锁一个 dummy key是,最安全的 fallback 机制
可否通过可视化排查是,黑白图可快速识别空行

http://www.dtcms.com/a/194136.html

相关文章:

  • 三格电子上新了——Modbus转IEC104网关
  • C42-作业练习
  • 速通RocketMQ配置
  • MySQL——3、数据类型
  • YOLOv8在单目向下多车辆目标检测中的应用
  • VsCode和AI的前端使用体验:分别使用了Copilot、通义灵码、iflyCode和Trae
  • CentOS系统中升级Python 3.12.2版本
  • 基于对抗性后训练的快速文本到音频生成:stable-audio-open-small 模型论文速读
  • 火语言RPA--EmpireV7下载发布
  • 【大模型面试每日一题】Day 20:大模型出现“幻觉”(Hallucination)的可能原因有哪些?如何从数据或训练层面缓解?
  • nosqlbooster pojie NoSQLBooster for MongoDB
  • 4.2.3 Thymeleaf标准表达式 - 5. 片段表达式
  • SAP ABAP 程序中归档数据读取方式
  • 在服务器上安装AlphaFold2遇到的问题(1)
  • 街景主观感知全流程(自建数据集+两两对比程序+Trueskill计算评分代码+训练模型+大规模预测)11
  • 在服务器上安装AlphaFold2遇到的问题(3)_cat: /usr/include/cudnn_version.h: 没有那个文件或目录
  • 【洗车店专用软件】佳易王洗车店多项目会员管理系统:一卡多用扣次软件系统实操教程 #扣次洗车管理软件
  • Spring框架(三)
  • 1688代采系统商品采集下单支付解决方案|官方API接口接入指南
  • npm cross-env工具包介绍(跨平台环境变量设置工具)
  • 机器学习第十五讲:决策树全面讲解:像玩“20个问题“游戏猜身份[特殊字符]
  • 国产linux系统(银河麒麟,统信uos)使用 PageOffice自定义Word模版中的数据区域
  • PyTorch深度学习框架60天进阶学习计划-第56天:大模型微调实践(二)
  • 【学习心得】Jupyter 如何在conda的base环境中其他虚拟环境内核
  • 动态IP赋能业务增效:技术解构与实战应用指南
  • Oracle数据库如何进行冷备份和恢复
  • 临床决策支持系统的提示工程优化路径深度解析
  • 使用IDEA开发Spark Maven应用程序【超详细教程】
  • Android framework 中间件开发(三)
  • docker(四)使用篇二:docker 镜像