【Debug日志| 随机下降】
断点续训越来越差?未恢复优化器/调度器/GradScaler 状态导致的收敛倒退(含可复现实验与通用检查点模板)
在我们进行模型训练的过程中,可能会遇到这么一种情况:从头训练一切正常,但一旦中途断点续训,loss 开始抖、准确率掉、甚至直接发散。数据与代码未改,唯一不同是“加载了上次的模型权重继续训练”。本篇复盘可复现实验、定位方法与可直接落地的保存/恢复模板。
❓ Bug 现象
- 断点续训后,学习率曲线突然跳变(回到 warmup 高位或峰值附近)。
- 同等步数下,loss 明显高于从头训练,准确率短时回退。
- AMP 训练里偶发首批 NaN,或几步内 loss 急剧波动。
- 打印优化器状态字典发现为空,或者调度器 last_epoch 与预期不符。
📽️ 场景复现
保存为 resume_bug_demo.py,按注释运行两段即可在 CPU 复现。
import argparse, math, os, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)def make_loader(n=4096, bs=64):X = torch.randn(n, 10)y = (X[:, 0] + 0.6 * X[:, 1] > 0).long()ds = torch.utils.data.TensorDataset(X, y)return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)class TinyNet(nn.Module):def __init__(self):super().__init__()self.m = nn.Sequential(nn.Linear(10,64), nn.ReLU(), nn.Linear(64,2))def forward(self, x): return self.m(x)def build_scheduler(optimizer, total_steps, warmup_steps=50):def lr_lambda(step):if step < warmup_steps: return (step + 1) / warmup_stepst = (step - warmup_steps) / max(1, total_steps - warmup_steps)return 0.5 * (1 + math.cos(math.pi * min(1.0, t)))return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)def save_ckpt(path, model, optimizer, scheduler, scaler, global_step):os.makedirs(os.path.dirname(path), exist_ok=True)torch.save({"model": model.state_dict(),"optimizer": optimizer.state_dict(),"scheduler": scheduler.state_dict(),"scaler": scaler.state_dict() if scaler is not None else None,"global_step": global_step,}, path)def load_ckpt(path, model, optimizer=None, scheduler=None, scaler=None):ckpt = torch.load(path, map_location="cpu")model.load_state_dict(ckpt["model"])if optimizer is not None and "optimizer" in ckpt: optimizer.load_state_dict(ckpt["optimizer"])if scheduler is not None and "scheduler" in ckpt: scheduler.load_state_dict(ckpt["scheduler"])if scaler is not None and ckpt.get("scaler") is not None: scaler.load_state_dict(ckpt["scaler"])return ckpt.get("global_step", 0)def train(phase, bug, steps):device = "cpu"model = TinyNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)scaler = None # GPU+AMP 时可替换为 GradScaler()total_steps = 400scheduler = build_scheduler(optimizer, total_steps)loader = make_loader()it = iter(loader)global_step = 0ckpt_path = "ckpts/demo.pt"if phase == "resume":if bug:_ = load_ckpt(ckpt_path, model) # 只加载模型,错误示范global_step = 0else:global_step = load_ckpt(ckpt_path, model, optimizer, scheduler, scaler)for _ in range(steps):try: x, y = next(it)except StopIteration:it = iter(loader); x, y = next(it)logits = model(x)loss = F.cross_entropy(logits, y)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()scheduler.step()global_step += 1if (global_step % 25) == 0:lr = optimizer.param_groups[0]["lr"]acc = (logits.argmax(1) == y).float().mean().item()tag = f"[{phase}|{'BUG' if bug else 'OK'}]"print(f"{tag} step={global_step:04d} lr={lr:.4f} loss={loss.item():.3f} acc={acc:.3f}")if phase == "pretrain":save_ckpt(ckpt_path, model, optimizer, scheduler, scaler, global_step)if __name__ == "__main__":ap = argparse.ArgumentParser()ap.add_argument("--phase", choices=["pretrain","resume"], required=True)ap.add_argument("--bug", choices=["on","off"], default="on")ap.add_argument("--steps", type=int, default=200)args = ap.parse_args()train(args.phase, args.bug=="on", args.steps)
运行方式
# 阶段1:从头训练并保存检查点
python resume_bug_demo.py --phase pretrain --steps 200# 阶段2a:错误续训(只加载模型权重)
python resume_bug_demo.py --phase resume --bug on --steps 200# 阶段2b:正确续训(加载模型+优化器+调度器+scaler+全局步)
python resume_bug_demo.py --phase resume --bug off --steps 200
后果
- 错误续训时 lr 会回到 warmup 峰值或高位,loss 短时恶化,acc 降低。
- 正确续训时 lr 连续,loss/acc 曲线平滑延续 pretrain 末尾的趋势。
- 若使用 AMP 且未恢复 GradScaler,首批更容易出现溢出或梯度无效。
Debug 过程
1️⃣ 检查优化器与调度器状态
打印优化器 state 是否为空、调度器 last_epoch 是否连续。
print("optimizer_has_state:", any(len(s)>0 for s in optimizer.state.values()))
print("scheduler_last_epoch:", getattr(scheduler, "last_epoch", None))
print("global_step:", global_step)
2️⃣ 记录学习率和损失
在 resume 的前 100 步高频打印 lr 与 loss,若 lr 不连续或 loss 突升,优先怀疑未恢复状态与步数。
3️⃣ AMP 的数值检查
半精度训练中,恢复后打印 scaler.get_scale(),若恢复为默认初值且随即出现溢出/underflow,需要同步加载 scaler 状态。
代码修改
1️⃣ 保存检查点时同步写入所有训练态
state = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"scheduler": scheduler.state_dict() if scheduler else None,"scaler": scaler.state_dict() if scaler else None,"global_step": global_step,"epoch": epoch,
}
torch.save(state, ckpt_path)
2️⃣ 恢复时按照先构建后加载的顺序加载全部状态
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state["model"])
if optimizer and state.get("optimizer"): optimizer.load_state_dict(state["optimizer"])
if scheduler and state.get("scheduler"): scheduler.load_state_dict(state["scheduler"])
if scaler and state.get("scaler"): scaler.load_state_dict(state["scaler"])
global_step = state.get("global_step", 0)
epoch = state.get("epoch", 0)
3️⃣ 训练循环里以 global_step 为唯一驱动
将日志、评估、保存、调度等触发条件统一用 global_step,避免 resume 后 epoch 边界重复或跳过。
if (global_step % log_every) == 0: ...
if (global_step % eval_every) == 0: ...
if (global_step % save_every) == 0: ...
Q & A
-
OneCycleLR 或 CosineAnnealingLR 的 total_steps/T_max 该怎么设
建议基于优化步数而非 batch 数或 epoch 数计算。若启用梯度累积,应使用 (len(dataloader) // accum_steps) × epochs。resume 时保持 total_steps 不变,并恢复 scheduler 的内部步数。
-
梯度累积是否影响调度步进
若采用每步调度,应在完成一次优化步时再 step 调度器,且 resume 后 last_epoch 与累计的优化步对齐。
-
断点落点在 epoch 中间怎么办
强烈建议用 global_step 做所有触发条件,resume 后自然对齐;若用 ep och 边界,请保存 last_batch_idx 并在恢复时跳过已完成的 batch。
-
分布式训练如何保存与恢复
DDP/FSDP 下通常在 rank0 保存,恢复时先构建并 wra p 模型,再加载 state_dict。FSDP 建议使用库提供的全局一致性 checkpoint 接口。
结语
断点续训不是“从当前 loss 继续”,而是“从当前优化动力学继续”。只加载模型权重相当于丢掉了动量、学习率位置与半精度缩放的全部历史信息,难免出现曲线回退与不稳定。把优化器、调度器、GradScaler 与 global_step 一并纳入检查点模板,并在恢复时做一次完整的自检,续训就能与从头训练保持一致的轨迹与表现。