当前位置: 首页 > news >正文

显存优化:梯度检查点Gradient Checkpoint和梯度累积Gradient Accumulation

梯度检查点(Gradient Checkpointing) 和 梯度累积(Gradient Accumulation) 是两种不同的显存优化技术,虽然目标类似(减少训练时的显存占用),但实现原理完全不同。以下是详细解释:


1. 梯度检查点(Gradient Checkpointing)

原理
  • 核心思想:用时间换空间,通过选择性丢弃中间激活值,在反向传播时重新计算它们,从而减少显存占用。

  • 工作流程

    1. 前向传播:只保存部分关键层的激活值(checkpoints),其余中间结果被丢弃。

    2. 反向传播:根据保存的检查点,重新计算被丢弃的中间结果(额外计算开销)。

  • 显存节省:显存占用可减少到原来的 1/√N(N 为模型层数),但会增加约 30% 的计算时间。

适用场景
  • 大模型训练(如扩散模型、LLM),显存不足但计算资源充足时。

  • 在代码中通常通过 torch.utils.checkpoint 或库(如 HuggingFace diffusers)的 enable_gradient_checkpointing() 启用。

from torch.utils.checkpoint import checkpoint# 前向传播时启用检查点
def forward_with_checkpoint(x):return checkpoint(custom_forward, x)  # custom_forward 是自定义的前向函数

2. 梯度累积(Gradient Accumulation)

原理
  • 核心思想:将一个大 batch 拆分成多个小 batch,累积多个小 batch 的梯度后再更新参数,模拟大 batch 的效果。

  • 工作流程

    1. 对小 batch 计算梯度,但不立即更新参数(optimizer.step())。

    2. 累积多次梯度后,统一更新参数。

  • 显存节省:显存占用与小 batch 相同,但训练时间更长(因需多次前向/反向)。

适用场景
  • 需要大 batch 但显存不足时(如目标检测、大语言模型微调)。

  • 在训练脚本中通过 accumulation_steps 参数控制。

optimizer.zero_grad()
for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()  # 梯度累积,不立即清零if (i + 1) % accumulation_steps == 0:optimizer.step()  # 累积足够步数后更新参数optimizer.zero_grad()

关键区别

特性梯度检查点(Checkpointing)梯度累积(Accumulation)
目标减少激活值显存占用模拟大 batch 训练
显存节省来源丢弃并重算中间结果使用小 batch 多次累积梯度
计算开销增加反向传播计算量(时间换空间)增加训练步数(时间换 batch 大小)
代码实现torch.utils.checkpointaccumulation_steps 参数

在扩散模型中的实际应用

  • 梯度检查点
    在训练 Stable Diffusion 等大模型时,启用 enable_gradient_checkpointing() 可将显存从 24GB 降至 16GB 左右,但训练速度会变慢。

  • 梯度累积
    若想用更大的 batch size(如提升训练稳定性),可通过累积梯度实现,但不会减少单步显存占用。

建议根据硬件条件组合使用两者(如同时启用检查点和累积梯度),以平衡显存和训练效率。

相关文章:

  • 【嵌入式DIY实例-Arduino篇】-DIY遥控手柄
  • Java SpringMVC与MyBatis整合
  • Grafana v12.0 引入了多项新功能和改进
  • Docker 部署Nexus仓库 搭建Maven私服仓库 公司内部仓库
  • 软考第五章知识点总结
  • [Java实战]Spring Boot 整合 Freemarker (十一)
  • RAGMCP基本原理说明和相关问题解惑
  • 1.5 提示词工程(一)
  • USB学习【6】USB传输错误的处理
  • 基于去中心化与AI智能服务的web3钱包的应用开发的背景描述
  • 湖北理元理律师事务所债务优化体系拆解:科学规划如何实现“还款不降质”
  • [ERTS2012] 航天器星载软件形式化模型驱动研发 —— 对 Scade 语言本身的影响
  • 使用 Java 反射动态加载和操作类
  • 【前端】【HTML】【总复习】一万六千字详解HTML 知识体系
  • 事务(理解)与数据库连接池
  • 【AI论文】作为评判者的感知代理:评估大型语言模型中的高阶社会认知
  • 【Java学习笔记】instanceof操作符
  • Quantum convolutional nerual network
  • Web开发—Vue工程化
  • stm32实战项目:无刷驱动
  • 睡觉总做梦是睡眠质量差?梦到这些事,才要小心
  • 上海与世界|环城生态公园带是上海绿色发展新名片
  • 阶跃星辰CEO姜大昕:追求智能上限仍是最重要的事,多模态的“GPT-4时刻”尚未到来
  • 高培勇:中国资本市场的发展应将预期因素全面纳入分析和监测体系
  • 匈牙利外长称匈方已驱逐两名乌克兰外交官
  • 习近平向中国人民解放军仪仗队致意