当前位置: 首页 > news >正文

浅谈ai - Activation Checkpointing - 时间换空间

前言

曾在游戏世界挥洒创意,也曾在前端和后端的浪潮间穿梭,如今,而立的我仰望AI的璀璨星空,心潮澎湃,步履不停!愿你我皆乘风破浪,逐梦星辰!

Activation Checkpointing(激活检查点),在 DeepSpeed 里是一个非常实用的显存优化技术。简单来说,它的核心目标就是:

“节省显存,代价是多做点计算。”

下面我给你讲明白它是啥、怎么干的、为什么有效、啥时候该用。


🧠 一句话解释

Activation Checkpointing 就是:

在前向传播时不保存所有中间激活(activation),而是只保存“关键点”的激活;
等到反向传播需要时,再重新计算中间丢掉的部分


📦 举个例子(以 Transformer 为例)

假设你有一个 12 层的 Transformer:

  • 正常训练时,前向传播完 12 层,会保存 每一层的激活输出(用于反向传播时算梯度)。

  • 这些激活可能占你一大半的显存!

💡 但是,如果你设置了 activation checkpointing:

  • 你只保存第 0、4、8、12 层的激活(比如每隔 4 层存一个“检查点”)。

  • 当反向传播需要第 2 层激活时,Deepspeed 会:

    • 重新从第 0 层开始计算到第 2 层。

    • 得到激活,然后继续正常反传。

  • 这样就省下了很多显存,但代价是计算时间会稍微变长(因为要“回头重算”)。


🔍 技术实现机制

步骤做了什么
前向传播只保存“检查点”处的激活,其他层的激活丢掉
反向传播当需要中间层激活时,重新从上一个检查点 forward 一遍
PyTorch 实现通常使用 torch.utils.checkpoint.checkpoint(function, *inputs) 来做
Deepspeed 支持可以自动为模型的某些模块启用 checkpoint,比如 TransformerBlock

🧮 显存节省 vs 计算开销

优点缺点
大幅减少激活存储显存(最多可省 50%+)增加前向传播的计算量(因为反向时要重算一部分)
可以训练更大的模型 / 更大 batch训练速度略慢(但通常接受)

典型使用场景

  • 显存紧张,训练不了大的 batch 或模型。

  • 训练 GPT/BERT 这类“块状重复”的模型时特别有效(容易切分 block)。


✅ 在 Deepspeed 中如何启用?

deepspeed_config.json 中这样配:

{
  "activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "cpu_checkpointing": false,
    "number_checkpoints": 4
  }
}
配置项含义
partition_activations按 rank 切分激活,进一步省显存
contiguous_memory_optimization内存优化,避免碎片化
cpu_checkpointing是否把 checkpoint 存 CPU(更慢,慎用)
number_checkpoints控制多少层之间插一个 checkpoint(类似“每几层存一次”)

📌 总结一句话

Activation Checkpointing 是一种“用时间换空间”的策略,省显存的同时代价是略微增加计算。对于大模型(如 GPT、BERT)训练来说是非常常见的标配技术。

如果你要训练 13B、30B 这种大模型或者 batch 太大撑不住,那这技术几乎是必开项了。

http://www.dtcms.com/a/113245.html

相关文章:

  • HANA如何在存储过程里执行动态SQL
  • 智慧节能双突破 强力巨彩谷亚VK系列刷新LED屏使用体验
  • 初识Linux-基本常用指令(一篇学会操作指令)
  • 03.unity开发资源 获取
  • 05.unity 游戏开发-3D工程的创建及使用方式和区别
  • Windows程序中计时器WM_TIMER消息的使用
  • Golang的Goroutine(协程)与runtime
  • 使用MATIO库读取Matlab数据文件中的稀疏矩阵
  • JAVA阻塞队列
  • OrangePi入门教程(待更新)
  • C++开发工具全景指南
  • 【java】在 Java 中,获取一个类的`Class`对象有多种方式
  • 6.5.图的基本操作
  • YOLOX 检测头以及后处理
  • 联网汽车陷入网络安全危机
  • 贪心算法之任务选择问题
  • mmap函数的概念和使用方案
  • 爬楼梯问题-动态规划
  • 3536 矩形总面积
  • leetcode4.寻找两个正序数组中的中位数
  • 类 和 对象 的介绍
  • 2024 .11-2025.3 一些新感悟
  • 【33期获取股票数据API接口】如何用Python、Java等五种主流语言实例演示获取股票行情api接口之沪深A股当天逐笔交易数据及接口API说明文档
  • 【2020】【论文笔记】相变材料与超表面——
  • 使用Cusor 生成 Figma UI 设计稿
  • 数据库并发控制问题
  • 麒麟系统桌面版本v10安装教程
  • 【动手学深度学习】卷积神经网络(CNN)入门
  • 低代码开发平台:飞帆画 echarts 柱状图
  • pygame里live2d的使用方法(live2d-py)