当前位置: 首页 > news >正文

训练模型时梯度出现NAN或者inf

参考:

训练模型时梯度出现NAN或者INF(禁用amp的不同level)_grad norm inf-CSDN博客

Got a nan loss and gradient norm when training swin-l on imagenet22k with O1 · Issue #82 · microsoft/Swin-Transformer

Got a nan loss and gradient norm when training swin-l on imagenet22k with O1 · Issue #82 · microsoft/Swin-Transformer

首先查找原因:找到是在哪一层断开了梯度

实现代码:

# loss 数值诊断
if torch.isnan(loss) or torch.isinf(loss):print("❌ Loss 出现 NaN/Inf")# 梯度诊断
for name, param in model.named_parameters():if param.grad is not None:if torch.isnan(param.grad).any():print(f"❌ {name} 梯度中存在 NaN")elif torch.isinf(param.grad).any():print(f"⚠️ {name} 梯度中存在 Inf")

排查可能原因:

1. 模型中存在未初始化或未更新的参数(层)

2. 除以0或者log引起

3.输入数据存在你nan或者inf

4. 学习率过大造成梯度不稳定

5.数据类型问题

由于梯度更新使用了AMP,加之具体问题:前1-2个epoch的grad norm出现 nan, 后面又稳定了,偶尔又会出现inf。合理怀疑是数据类型的问题。

PyTorch 的 AMP(自动混合精度) 默认支持动态切换精度。它会在前向和后向传播中自动判断是否切换为 float16 精度,以节省显存并加速计算。在使用 AMP 时,通常采用以下几种机制来选择精度:

按操作动态调整精度:AMP 会根据具体操作的数值稳定性来选择 float32 或 float16,对于稳定性较好的操作(如矩阵乘法)使用 float16,对精度要求较高的操作(如归一化)则保留 float32。

GradScaler 动态调整梯度缩放:AMP 默认使用 GradScaler 对梯度进行缩放,以避免因 float16 造成的数值下溢(过小梯度被舍去)。

这种自动化过程旨在最大程度保持数值稳定性,并降低显存需求。只需使用 torch.cuda.amp.autocast 上下文管理器和 GradScaler,AMP 就能完成动态精度切换

import torch
from torch.cuda.amp import GradScaler, autocast# ========= 训练前配置部分 =========
# 默认使用 AMP(float16),可通过命令行参数禁用或修改精度选项
use_amp = not args.disable_amp
amp_dtype = torch.float16  # 可选:torch.bfloat16, torch.float16if args.amp_opt_level == 'O0':use_amp = False
elif args.amp_opt_level == 'O1':amp_dtype = torch.float16
elif args.amp_opt_level == 'O2':amp_dtype = torch.float16
elif args.amp_opt_level == 'O3':amp_dtype = torch.float16
elif args.amp_opt_level == 'bf16':amp_dtype = torch.bfloat16
else:amp_dtype = torch.float16  # 默认 fallback# AMP梯度缩放器(避免数值下溢)
scaler = GradScaler(enabled=use_amp)# ========= 模型初始化 =========
model = MyModel()
model = model.to(device)# 如果使用 bfloat16,可考虑转换模型参数
if amp_dtype == torch.bfloat16 and use_amp:model = model.to(dtype=torch.bfloat16)# ========= 训练过程 =========
for epoch in range(num_epochs):model.train()for batch in dataloader:inputs, targets = batchinputs = inputs.to(device)targets = targets.to(device)if amp_dtype == torch.bfloat16 and use_amp:inputs = inputs.to(dtype=torch.bfloat16)targets = targets.to(dtype=torch.bfloat16)optimizer.zero_grad()with autocast(enabled=use_amp, dtype=amp_dtype):outputs = model(inputs)loss = criterion(outputs, targets)if use_amp:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()print(f"Epoch {epoch} finished.")# ========= 验证及日志 =========
# 可加入日志记录:记录grad norm是否为nan,或loss是否稳定
def log_gradient_norm(model):total_norm = 0.0for p in model.parameters():if p.grad is not None:param_norm = p.grad.data.norm(2)if torch.isnan(param_norm) or torch.isinf(param_norm):print("⚠️ Gradient contains NaN or Inf")total_norm += param_norm.item() ** 2total_norm = total_norm ** 0.5print(f"Gradient norm: {total_norm:.4f}")

解决方案:

1. 禁用自动混合精度(AMP):

如果你不依赖于 bfloat16 的性能优化,可以选择禁用 AMP。你可以在你的主函数中设置 --disable_amp 参数,或者在代码中直接注释掉与 GradScaler 和 autocast 相关的代码。这将避免因 bfloat16 引起的问题。

. --disable_amp
类型: 布尔型(action='store_true')
功能: 如果指定了这个参数,将会禁用 PyTorch 的自动混合精度功能。在训练过程中,这意味着模型将会使用全精度(通常是 float32)进行计算,而不使用混合精度。
适用场景: 在调试或遇到精度问题时,可以选择禁用 AMP。


2. 使用 bfloat16:

3.设置级别

模式推荐设置使用场景
禁用 AMP--disable_amp--amp-opt-level O0调试/排查数值问题最稳妥
安全混合精度--amp-opt-level O1(默认)高性能但更安全的float16使用方案
激进混合精度--amp-opt-level O2对性能要求较高,模型已测试数值稳定时使用
bfloat16--amp-opt-level bf16 + A100以上GPU更高稳定性的混合精度方案,适用于新硬件
强制float16--amp-opt-level O3实验阶段或推理速度要求极高但需谨慎
http://www.dtcms.com/a/321439.html

相关文章:

  • WiFi 核心概念与实战用例全解
  • git环境配置_笔记
  • [Linux]学习笔记系列 -- [arm[kernel]
  • modem上报SIM卡状态为unknown问题分析
  • 6_基于深度学习的火灾检测识别系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
  • 控制建模matlab练习13:线性状态反馈控制器-②系统的能控性
  • #C语言——刷题攻略:牛客编程入门训练(六):运算(三)-- 涉及 辗转相除法求最大公约数
  • JAVA,Maven聚合
  • 【记录】yumdownloader 和 yum install --downloadonly
  • Linux线程概念
  • 一洽客服系统:APP路由等级与路由条件设置
  • FreeRTOS入门知识(初识RTOS)(二)
  • NOIP普及组|2000T1进制转换
  • Spring Boot + Angular 实现安全登录注册系统:全栈开发指南
  • GPT-5 来了:和它改变世界的 982 天
  • 存储管理、XFS 增量备份恢复、LVM
  • Unity_VR_Pico开发手册
  • 防火墙安全作用及 firewalld 交互、端口配置
  • 基于嵌入式Linux RK3568 qt 车机系统开发
  • 前端单元测试最佳实践(一)
  • Linux系统编程Day9 -- gdb (linux)和lldb(macOS)调试工具
  • ubuntu安装gpu驱动
  • 中国的超算中心使用情况如何?是否算力过剩
  • vscode 配置
  • STM32学习笔记5-TIM定时器-1
  • windows安装Docker Desktop报错One prerequisite is not fulfilled.
  • CVPR2025-DEIM改进项目介绍
  • Linux 运维与优化的系统化思维:从内核到生产环境的全链路管理
  • 景区车辆监控调度管理平台建设方案
  • Jetpack Compose 常用控件