Pytorch 报错-probability tensor contains either ‘inf‘, ‘nan‘ or element < 0 解决方案
Error information
/pytorch/aten/src/ATen/native/cuda/TensorCompare.cu:110: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.
模型生成错误: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions
错误原因分析
1. 概率张量包含无效值
错误信息 probability tensor contains either 'inf', 'nan' or element < 0 表明在生成过程中,模型计算出的概率分布包含了:
- inf(无穷大)
- nan(非数字)
- 负数值
2. 为什么设置 do_sample=True 会触发这个错误
当 do_sample=True 时,模型会:
- 计算词汇表上所有token的概率分布
- 根据概率分布进行采样选择下一个token
- 这个过程需要计算softmax概率
而 do_sample=False(贪婪解码)时:
- 直接选择概率最高的token
- 不需要计算完整的概率分布
- 避免了概率计算中的数值问题
3. 根本原因分析
模型权重或输入问题:
- 模型可能存在数值不稳定性
- 输入数据的某些特征可能导致logits计算出现极端值
- 模型在特定输入下的数值计算溢出
硬件或环境问题:
- GPU内存不足导致计算精度问题
- CUDA版本与PyTorch版本不兼容
- 模型加载时的精度设置(torch_dtype=torch.float16)可能导致数值不稳定
生成参数冲突:
- temperature=0.7 与 do_sample=True 的组合可能产生数值问题
- 当temperature较小时,某些token的概率可能趋近于0,导致数值不稳定
4. 为什么贪婪解码不会出错
贪婪解码(do_sample=False):
- 只选择最大概率的token
- 不需要计算完整的概率分布
- 避免了softmax计算中的数值问题
- 即使有数值问题,argmax操作相对稳定
5. 具体触发机制
- 模型计算logits(未归一化的概率)
- 应用temperature缩放:logits / temperature
- 计算softmax概率分布
- 如果logits包含极端值,temperature缩放后可能产生:
- 过大的正值 → inf
- 过小的负值 → 数值下溢
- 0/0 或 inf/inf → nan
6. 解决方案思路
短期解决:
- 继续使用 do_sample=False(贪婪解码)
- 调整其他参数如 temperature、top_p 等
长期解决:
- 检查模型权重是否完整下载
- 使用更高的数值精度(torch.float32)
- 添加数值稳定性检查
- 使用更稳定的采样方法(如top-k、top-p采样)
这个错误本质上是深度学习模型在数值计算中的常见问题,特别是在使用半精度浮点数(float16)和采样生成时更容易出现。
解决方案
1. 数值稳定性问题
- 模型在采样模式下存在数值计算不稳定
- Float16精度可能导致logits计算时出现inf/nan值
- CUDA内核错误表明在概率计算阶段出现了数值溢出
2. 模型状态问题
- 模型可能没有正确加载或初始化
- 权重文件可能损坏或不完整
- 设备内存不足导致计算精度问题
3. 输入数据问题
- 输入token序列可能包含特殊字符或格式问题
- 输入长度可能超出模型处理能力
- 输入内容可能触发模型的数值不稳定区域
逐一排查,发现最终原因
- Float16精度可能导致logits计算时出现inf/nan值
解决方法:
# 将模型加载时的精度从float16改为float32
torch_dtype=torch.float32 # 而不是torch.float16
即可生成正常prediction。