当前位置: 首页 > news >正文

【Debug日志| 随机下降】

断点续训越来越差?未恢复优化器/调度器/GradScaler 状态导致的收敛倒退(含可复现实验与通用检查点模板)

在我们进行模型训练的过程中,可能会遇到这么一种情况:从头训练一切正常,但一旦中途断点续训,loss 开始抖、准确率掉、甚至直接发散。数据与代码未改,唯一不同是“加载了上次的模型权重继续训练”。本篇复盘可复现实验、定位方法与可直接落地的保存/恢复模板。

❓ Bug 现象

  • 断点续训后,学习率曲线突然跳变(回到 warmup 高位或峰值附近)。
  • 同等步数下,loss 明显高于从头训练,准确率短时回退。
  • AMP 训练里偶发首批 NaN,或几步内 loss 急剧波动。
  • 打印优化器状态字典发现为空,或者调度器 last_epoch 与预期不符。

📽️ 场景复现

保存为 resume_bug_demo.py,按注释运行两段即可在 CPU 复现。

import argparse, math, os, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)def make_loader(n=4096, bs=64):X = torch.randn(n, 10)y = (X[:, 0] + 0.6 * X[:, 1] > 0).long()ds = torch.utils.data.TensorDataset(X, y)return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)class TinyNet(nn.Module):def __init__(self):super().__init__()self.m = nn.Sequential(nn.Linear(10,64), nn.ReLU(), nn.Linear(64,2))def forward(self, x): return self.m(x)def build_scheduler(optimizer, total_steps, warmup_steps=50):def lr_lambda(step):if step < warmup_steps: return (step + 1) / warmup_stepst = (step - warmup_steps) / max(1, total_steps - warmup_steps)return 0.5 * (1 + math.cos(math.pi * min(1.0, t)))return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)def save_ckpt(path, model, optimizer, scheduler, scaler, global_step):os.makedirs(os.path.dirname(path), exist_ok=True)torch.save({"model": model.state_dict(),"optimizer": optimizer.state_dict(),"scheduler": scheduler.state_dict(),"scaler": scaler.state_dict() if scaler is not None else None,"global_step": global_step,}, path)def load_ckpt(path, model, optimizer=None, scheduler=None, scaler=None):ckpt = torch.load(path, map_location="cpu")model.load_state_dict(ckpt["model"])if optimizer is not None and "optimizer" in ckpt: optimizer.load_state_dict(ckpt["optimizer"])if scheduler is not None and "scheduler" in ckpt: scheduler.load_state_dict(ckpt["scheduler"])if scaler is not None and ckpt.get("scaler") is not None: scaler.load_state_dict(ckpt["scaler"])return ckpt.get("global_step", 0)def train(phase, bug, steps):device = "cpu"model = TinyNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)scaler = None  # GPU+AMP 时可替换为 GradScaler()total_steps = 400scheduler = build_scheduler(optimizer, total_steps)loader = make_loader()it = iter(loader)global_step = 0ckpt_path = "ckpts/demo.pt"if phase == "resume":if bug:_ = load_ckpt(ckpt_path, model)   # 只加载模型,错误示范global_step = 0else:global_step = load_ckpt(ckpt_path, model, optimizer, scheduler, scaler)for _ in range(steps):try:   x, y = next(it)except StopIteration:it = iter(loader); x, y = next(it)logits = model(x)loss = F.cross_entropy(logits, y)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()scheduler.step()global_step += 1if (global_step % 25) == 0:lr = optimizer.param_groups[0]["lr"]acc = (logits.argmax(1) == y).float().mean().item()tag = f"[{phase}|{'BUG' if bug else 'OK'}]"print(f"{tag} step={global_step:04d} lr={lr:.4f} loss={loss.item():.3f} acc={acc:.3f}")if phase == "pretrain":save_ckpt(ckpt_path, model, optimizer, scheduler, scaler, global_step)if __name__ == "__main__":ap = argparse.ArgumentParser()ap.add_argument("--phase", choices=["pretrain","resume"], required=True)ap.add_argument("--bug", choices=["on","off"], default="on")ap.add_argument("--steps", type=int, default=200)args = ap.parse_args()train(args.phase, args.bug=="on", args.steps)

运行方式

# 阶段1:从头训练并保存检查点
python resume_bug_demo.py --phase pretrain --steps 200# 阶段2a:错误续训(只加载模型权重)
python resume_bug_demo.py --phase resume --bug on --steps 200# 阶段2b:正确续训(加载模型+优化器+调度器+scaler+全局步)
python resume_bug_demo.py --phase resume --bug off --steps 200

后果

  • 错误续训时 lr 会回到 warmup 峰值或高位,loss 短时恶化,acc 降低。
  • 正确续训时 lr 连续,loss/acc 曲线平滑延续 pretrain 末尾的趋势。
  • 若使用 AMP 且未恢复 GradScaler,首批更容易出现溢出或梯度无效。

Debug 过程

1️⃣ 检查优化器与调度器状态
打印优化器 state 是否为空、调度器 last_epoch 是否连续。

print("optimizer_has_state:", any(len(s)>0 for s in optimizer.state.values()))
print("scheduler_last_epoch:", getattr(scheduler, "last_epoch", None))
print("global_step:", global_step)

2️⃣ 记录学习率和损失
在 resume 的前 100 步高频打印 lr 与 loss,若 lr 不连续或 loss 突升,优先怀疑未恢复状态与步数。

3️⃣ AMP 的数值检查
半精度训练中,恢复后打印 scaler.get_scale(),若恢复为默认初值且随即出现溢出/underflow,需要同步加载 scaler 状态。

代码修改

1️⃣ 保存检查点时同步写入所有训练态

state = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"scheduler": scheduler.state_dict() if scheduler else None,"scaler": scaler.state_dict() if scaler else None,"global_step": global_step,"epoch": epoch,
}
torch.save(state, ckpt_path)

2️⃣ 恢复时按照先构建后加载的顺序加载全部状态

state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state["model"])
if optimizer and state.get("optimizer"): optimizer.load_state_dict(state["optimizer"])
if scheduler and state.get("scheduler"): scheduler.load_state_dict(state["scheduler"])
if scaler and state.get("scaler"): scaler.load_state_dict(state["scaler"])
global_step = state.get("global_step", 0)
epoch = state.get("epoch", 0)

3️⃣ 训练循环里以 global_step 为唯一驱动
将日志、评估、保存、调度等触发条件统一用 global_step,避免 resume 后 epoch 边界重复或跳过。

if (global_step % log_every) == 0: ...
if (global_step % eval_every) == 0: ...
if (global_step % save_every) == 0: ...

Q & A

  • OneCycleLR 或 CosineAnnealingLR 的 total_steps/T_max 该怎么设

    建议基于优化步数而非 batch 数或 epoch 数计算。若启用梯度累积,应使用 (len(dataloader) // accum_steps) × epochs。resume 时保持 total_steps 不变,并恢复 scheduler 的内部步数。

  • 梯度累积是否影响调度步进

    若采用每步调度,应在完成一次优化步时再 step 调度器,且 resume 后 last_epoch 与累计的优化步对齐。

  • 断点落点在 epoch 中间怎么办

    强烈建议用 global_step 做所有触发条件,resume 后自然对齐;若用 ep och 边界,请保存 last_batch_idx 并在恢复时跳过已完成的 batch。

  • 分布式训练如何保存与恢复

    DDP/FSDP 下通常在 rank0 保存,恢复时先构建并 wra p 模型,再加载 state_dict。FSDP 建议使用库提供的全局一致性 checkpoint 接口。

结语

断点续训不是“从当前 loss 继续”,而是“从当前优化动力学继续”。只加载模型权重相当于丢掉了动量、学习率位置与半精度缩放的全部历史信息,难免出现曲线回退与不稳定。把优化器、调度器、GradScaler 与 global_step 一并纳入检查点模板,并在恢复时做一次完整的自检,续训就能与从头训练保持一致的轨迹与表现。

http://www.dtcms.com/a/388826.html

相关文章:

  • 滑动窗口法的优化与实战——力扣209.长度最小的子数组
  • 【Spring Boot 报错已解决】org.yaml.snakeyaml.scanner.ScannerException 报错原因与解决方案
  • 国家统计局数据读取——数据读取——清洗数据06
  • 基于 scratch 构建简单镜像
  • Web安全的暗角:10大易忽略逻辑漏洞解析!
  • 矩阵奇异值分解算法(SVD)详解
  • 【FreeRTOS】 二值信号量与互斥量(CMSIS-RTOS v2 版本)
  • Qt C++ :Qt全局定义<QtGlobal>
  • 【STL源码剖析】从源码看 list:从迭代器到算法
  • MySQL 专题(三):事务与锁机制深度解析
  • 使用BLIP训练自己的数据集(图文描述)
  • Geoserver修行记--在geoserver中如何复制某个图层组内容
  • DBG数据库透明加密网关:SQLServer应用免改造的安全防护方案,不限制开发语言的加密网关
  • 不同上位开发语言、PLC下位平台、工业协议与操作系统平台下的数据类型通用性与差异性详解
  • 【入门篇|第二篇】从零实现选择、冒泡、插入排序(含对数器)
  • javaweb Servlet基本介绍及开发流程
  • MySQL MHA高可用
  • 整体设计 逻辑拆解之2 实现骨架:一元谓词+ CNN的谓词系统
  • SpEL(Spring Expression Language)学习笔记
  • Java 字节码进阶3:面向对象多态在字节码层面的原理?
  • Tensor :核心概念、常用函数与避坑指南
  • 机器学习实战·第四章 训练模型(1)
  • 一次因表单默认提交导致的白屏排查记录
  • Linux:io_uring
  • 《第九课——C语言判断:从Java的“文明裁决“到C的“原始决斗“——if/else的生死擂台与switch的轮盘赌局》
  • 学习日报|Spring 全局异常与自定义异常拦截器执行顺序问题及解决
  • Spring Boot 参数处理
  • Debian系统基本介绍:新手入门指南
  • Spring Security 框架
  • Qt QPercentBarSeries详解