当前位置: 首页 > news >正文

等效学习率翻倍?梯度累积三连坑:未除以 accum_steps、调度器步进错位、梯度裁剪/正则标度错误(含可复现实验与修复模板)

等效学习率翻倍?梯度累积三连坑:未除以 accum_steps、调度器步进错位、梯度裁剪/正则标度错误(含可复现实验与修复模板)

在我们进行模型微调的时候,显存的限制促使我们要调小batch_size以及降低模型复杂度。还有一个办法就是用梯度累积把每步的 batch 拆成多个小 micro-batch。看起来 loss 会降,但曲线比不累积还抖,越训越不稳;同样的总 batch 和学习率,指标却明显更差。笔者总结了可能出现的现象以及debug过程供大家参考学习。


❓ Bug 现象

  • 相同的总 batch 与学习率,开启梯度累积后 loss 波动变大,收敛更慢。
  • 同样跑完 N 个优化步,出现 learning rate 步进次数翻倍或与预期不一致。
  • 做梯度裁剪时,未分摊导致裁剪更频繁,更新幅度异常小(或爆炸)。
  • 用 Adam(weight_decay=λ) 而不是 AdamW 时,忘了分摊损失里的 L2 项,正则强度等效放大。

📽️ 场景复现

保存为 accum_lr_debug.py,CPU 即可跑通。

# accum_lr_debug.py
import argparse, math, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)def make_loader(n=8192, bs=64):X = torch.randn(n, 20)y = (X[:, 0] + 0.6*X[:, 1] - 0.3*X[:, 2] > 0).long()ds = torch.utils.data.TensorDataset(X, y)return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)class MLP(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(20,128), nn.ReLU(), nn.Linear(128,2))def forward(self, x): return self.net(x)def build_cosine(optimizer, total_steps, warmup=50):def lr_lambda(step):if step < warmup: return (step+1)/max(1,warmup)t = (step-warmup)/max(1,total_steps-warmup)return 0.5*(1+math.cos(math.pi*min(1.0,t)))return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)def run(accum_steps=8, bug=True, total_updates=300):device = "cpu"model = MLP().to(device)# 故意用 Adam + L2 展示“耦合正则”问题;修复版会换 AdamWif bug:optimizer = torch.optim.Adam(model.parameters(), lr=3e-3, weight_decay=1e-2)else:optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-2)scheduler = build_cosine(optimizer, total_updates)loader = make_loader()it = iter(loader)micro_bs = next(iter(loader))[0].shape[0]  # 只是打印用途optimizer.zero_grad(set_to_none=True)micro_count = 0for upd in range(1, total_updates+1):# 累积 accum_steps 个 micro-batchfor k in range(accum_steps):try: x, y = next(it)except StopIteration:it = iter(loader); x, y = next(it)out = model(x)loss = F.cross_entropy(out, y)if not bug:# 修复:分摊到每个 micro-batch(解耦正则用 AdamW,无需再分摊 L2)loss = loss / accum_steps# 梯度反传loss.backward()# 错误示范:在每个 micro-batch 都调度/裁剪if bug:nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 未分摊前裁剪,等效更严scheduler.step()  # 调度器按 micro-batch 步进micro_count += 1# 一次优化步if not bug:# 修复:先在“累积后的”梯度上裁剪,再 step;调度器只在优化步调用nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()optimizer.zero_grad(set_to_none=True)if not bug:scheduler.step()# 监控:有效学习率/更新幅度with torch.no_grad():gnorm = 0.0pn, up = 0.0, 0.0for p in model.parameters():if p.grad is not None:gnorm += p.grad.norm().item()pn += p.data.norm().item()lr = optimizer.param_groups[0]["lr"]if upd % 25 == 0:tag = "BUG" if bug else "FIX"print(f"[{tag}] upd={upd:03d} lr={lr:.4f} micro_bs={micro_bs} accum={accum_steps} gnorm≈{gnorm:.3f}")if __name__ == "__main__":ap = argparse.ArgumentParser()ap.add_argument("--accum", type=int, default=8)ap.add_argument("--bug", choices=["on","off"], default="on")args = ap.parse_args()print("== 错误设置 =="); run(accum_steps=args.accum, bug=True)print("\n== 正确设置 =="); run(accum_steps=args.accum, bug=False)

你会看到

  • 错误设置下,调度器每个 micro-batch 都 step,300 个优化步却走了 300×accum_steps 次学习率曲线;梯度裁剪触发更频繁。
  • 正确设置下,学习率随优化步平滑变化,gnorm 更稳定。

Debug 过程

1️⃣ 打印基线与等效学习率
把 loss 是否除以 accum_steps 写进日志,计算有效学习率 lr_eff ≈ lr(若没除,就相当于放大了 accum_steps 倍的梯度)。

2️⃣ 检查调度器步进频率
每次 optimizer.step() 后再 scheduler.step()。若你有 300 个优化步,总的 scheduler.last_epoch 应该接近 300,而不是 300×accum_steps。

3️⃣ 梯度裁剪触发率
统计被裁剪比例/次数。如果未分摊就对 micro-batch 的梯度裁剪,等效阈值更紧,更新过小。

4️⃣正则项位置
用 Adam(L2) 时把 L2 加入损失,需要按 accum_steps 分摊;更稳妥的做法是换 AdamW(解耦权重衰减),避免 L2 与自适应缩放耦合。

代码修改

1️⃣ 分摊损失、只在优化步调度

loss = criterion(out, y) / accum_steps
loss.backward()
if (micro_idx + 1) % accum_steps == 0:# 累积完成后,再裁剪、再 steptorch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)optimizer.step()optimizer.zero_grad(set_to_none=True)scheduler.step()  # 只在优化步调用

2️⃣ 正确的权重衰减与 L2
优先使用 AdamW(weight_decay=λ)。如果历史原因必须使用 Adam+L2,把 L2 作为分摊后的附加项:

# 不推荐,但若必须:
loss = ce(out, y) / accum_steps + (l2_lambda / accum_steps) * sum((p**2).sum() for p in params_to_regularize)

3️⃣ AMP 下的梯度累积顺序,每个 micro-batch 都 scale 后 backward;累积完成后 unscale_ 再裁剪、再 step:

scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad(set_to_none=True)
for micro in range(accum_steps):with torch.cuda.amp.autocast():loss = criterion(model(xb), yb) / accum_stepsscaler.scale(loss).backward()# 累积结束
scaler.unscale_(optimizer)  # 把缩放还原到真实梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
scheduler.step()

4️⃣ 与分布式的交互,在 DDP/FSDP 下,梯度累积不改变梯度同步时机(默认每次 backward都会同步)。如需减少通信,使用 no_sync 仅在累积的前 accum_steps-1 个 micro-batch 上关闭同步,最后一个再开启。

for i, (x, y) in enumerate(loader):ctx = model.no_sync() if (i % accum_steps != accum_steps-1) else nullcontext()with ctx:loss = criterion(model(x), y) / accum_stepsloss.backward()if (i % accum_steps == accum_steps-1):torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)optimizer.step(); optimizer.zero_grad(); scheduler.step()

Q & A

  • 梯度累积会不会改变 BatchNorm 统计吗
    ✔️ 会。BN 只看当前 micro-batch 的统计,累积不能增大 BN 的有效 batch。小批建议用 SyncBatchNorm(多卡)或改 GroupNorm/LayerNorm。
  • OneCycleLR 该按什么步进
    ✔️ 按优化步进而不是 micro-batch。确定 total_steps 应基于 每轮优化步 = floor(len(loader)/accum_steps)。
  • 梯度裁剪应该在什么时候做
    ✔️ 在累积结束、step 前做;AMP 下先 unscale_ 再裁剪。
  • 两阶段/分组学习率如何处理
    ✔️分摊只作用在 loss。学习率调度仍以优化步为单位,按各 param group 正常工作。

结语

梯度累积不是“白嫖大 batch”,它对损失标度、调度步进、裁剪与正则都有连锁影响。把 loss/accum_steps 写死在模板里,把调度器与裁剪绑定到优化步,并用 AdamW 解耦权重衰减,会立刻得到与不累积相同的等效学习率与更稳定的曲线。复盘后发现有三个常见坑:没有把 loss 除以 accum_steps;学习率调度器按 micro-batch 步进;梯度裁剪和 L2 正则的标度没有随累积分摊。

http://www.dtcms.com/a/390979.html

相关文章:

  • 嵌入式学习笔记(44)IMX6ULL
  • OpenStack 学习笔记(五):网络管理和虚拟网络实践与存储管理实验(下)
  • 博睿数据携手华为共筑智能未来,深度参与HUAWEI CONNECT 2025并发表主题演讲
  • 陈童理论物理新讲1 哈密顿力学初步
  • 9.19 Sass
  • 设计模式详解:单例模式、工厂方法模式、抽象工厂模式
  • 终端同居物语:Shell咏唱术式与权限结界の完全解析书
  • XeLaTeX 中文删除线自动换行问题的解决方案
  • R语言中的因子(Factor)详解 factor_path <- as.factor(char_path)
  • 软件测试之⾃动化测试常⽤函数(沉淀中)
  • 火山引擎多模态数据湖:基于 Daft 与 Lance,构筑 AI 时代数据湖新范式
  • 关于强化学习的一篇经典学习文章
  • 【JavaScript 性能优化实战】第四篇:webpack 与 vite 打包优化实战
  • maven-profile实现springboot多版本配置打包
  • OpenLayers地图交互 -- 章节二:绘制交互详解:从基础几何到复杂图形的完整绘制方案
  • Java 工厂模式 + 策略模式实战:工具管理器的设计与实现
  • 污水处理厂远程调试与智能化运维解决方案
  • 【提示工程】Ch2-提示技术(Prompt Technique)
  • vLLM - Worker
  • GitHub上面仓库名写错了,怎么改
  • 项目中的图形验证码是前端还是后端实现?
  • ✅ 基于Scrapy与朴素贝叶斯的校园舆情监测与预警系统 Django+B/S架构 可视化大屏 机器学习
  • Unity UI 插件 | Easy Popup System
  • AI证件照制作 API 快速生成证件照
  • @RequestParam和 @RequestBody能一起用吗
  • 构建高效的电商爬虫代理池:从架构设计到实战优化
  • 使用cJSON库实现JSON与C结构体的互转
  • Cursor :Python 运行路径设置自定义模块导入报错:No module named ‘xxx’ 的解决方案
  • 数图信息科技亮相唐山社区零售论坛,数字化赋能行业高质量发展
  • LLM大模型 - 实战篇 - Assistant API 原理与实战应用