显存优化:梯度检查点Gradient Checkpoint和梯度累积Gradient Accumulation
梯度检查点(Gradient Checkpointing) 和 梯度累积(Gradient Accumulation) 是两种不同的显存优化技术,虽然目标类似(减少训练时的显存占用),但实现原理完全不同。以下是详细解释:
1. 梯度检查点(Gradient Checkpointing)
原理
-
核心思想:用时间换空间,通过选择性丢弃中间激活值,在反向传播时重新计算它们,从而减少显存占用。
-
工作流程:
-
前向传播:只保存部分关键层的激活值(checkpoints),其余中间结果被丢弃。
-
反向传播:根据保存的检查点,重新计算被丢弃的中间结果(额外计算开销)。
-
-
显存节省:显存占用可减少到原来的 1/√N(N 为模型层数),但会增加约 30% 的计算时间。
适用场景
-
大模型训练(如扩散模型、LLM),显存不足但计算资源充足时。
-
在代码中通常通过
torch.utils.checkpoint
或库(如 HuggingFacediffusers
)的enable_gradient_checkpointing()
启用。
from torch.utils.checkpoint import checkpoint# 前向传播时启用检查点
def forward_with_checkpoint(x):return checkpoint(custom_forward, x) # custom_forward 是自定义的前向函数
2. 梯度累积(Gradient Accumulation)
原理
-
核心思想:将一个大 batch 拆分成多个小 batch,累积多个小 batch 的梯度后再更新参数,模拟大 batch 的效果。
-
工作流程:
-
对小 batch 计算梯度,但不立即更新参数(
optimizer.step()
)。 -
累积多次梯度后,统一更新参数。
-
-
显存节省:显存占用与小 batch 相同,但训练时间更长(因需多次前向/反向)。
适用场景
-
需要大 batch 但显存不足时(如目标检测、大语言模型微调)。
-
在训练脚本中通过
accumulation_steps
参数控制。
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)loss.backward() # 梯度累积,不立即清零if (i + 1) % accumulation_steps == 0:optimizer.step() # 累积足够步数后更新参数optimizer.zero_grad()
关键区别
特性 | 梯度检查点(Checkpointing) | 梯度累积(Accumulation) |
---|---|---|
目标 | 减少激活值显存占用 | 模拟大 batch 训练 |
显存节省来源 | 丢弃并重算中间结果 | 使用小 batch 多次累积梯度 |
计算开销 | 增加反向传播计算量(时间换空间) | 增加训练步数(时间换 batch 大小) |
代码实现 | torch.utils.checkpoint | accumulation_steps 参数 |
在扩散模型中的实际应用
-
梯度检查点:
在训练 Stable Diffusion 等大模型时,启用enable_gradient_checkpointing()
可将显存从 24GB 降至 16GB 左右,但训练速度会变慢。 -
梯度累积:
若想用更大的 batch size(如提升训练稳定性),可通过累积梯度实现,但不会减少单步显存占用。
建议根据硬件条件组合使用两者(如同时启用检查点和累积梯度),以平衡显存和训练效率。