当前位置: 首页 > news >正文

Pytorch Lightning 进阶 1 - 梯度检查点(Gradient Checkpointing)

梯度检查点(Gradient Checkpointing)是一种在深度学习训练中优化显存使用的技术,尤其适用于处理大型模型(如Transformer架构)时显存不足的情况。下面用简单的例子解释其工作原理和优缺点:

核心原理

深度学习训练中的显存占用主要来自三个方面:

  1. 模型参数(如权重、偏置)
  2. 优化器状态(如Adam的动量项)
  3. 中间激活值(forward过程中产生的张量,如注意力图、隐藏层输出等)

其中,中间激活值通常占用最大的显存空间,尤其是在深层网络中。梯度检查点的核心思想是:

  • 在正向传播时:只保存少量关键的中间结果(称为“检查点”),其余中间值在计算后立即丢弃。
  • 在反向传播时:利用保存的检查点重新计算被丢弃的中间值,从而获得计算梯度所需的全部信息。

如下图所示:
在这里插入图片描述

这种方法通过牺牲计算时间(重新计算)来节省显存空间(无需保存所有中间值)。

为什么需要梯度检查点?

假设你有一个包含100层的Transformer模型,每层在forward过程中产生1GB的中间激活值:

  • 传统训练:需要保存所有100层的中间值,总显存需求为100GB。
  • 梯度检查点:只保存10个检查点(每层1GB),反向传播时通过检查点重新计算其余90层,总显存需求降至10GB。

代码中的应用

在你的代码中,gradient_checkpointing=True的配置会使模型在训练时启用梯度检查点:

trainable_model = Model(# ...其他参数gradient_checkpointing=training_config.get("gradient_checkpointing", False),# ...
)

这意味着:

  1. 正向传播时,模型不会保存所有注意力图隐藏层输出
  2. 反向传播时,PyTorch会利用检查点重新计算这些值,从而减少显存占用

优缺点

  • 优点:显著减少显存使用(通常能节省30%-50%的显存),允许训练更大的模型或使用更大的批次大小。
  • 缺点:增加训练时间(通常慢20%-30%),因为需要重新计算中间值。

何时使用?

  • 显存不足:当模型因显存限制无法训练时,梯度检查点是一种有效的解决方案。
  • 计算资源充足:如果你的GPU算力充足但显存有限,可以通过延长训练时间换取更小的显存占用。

技术细节

在PyTorch中,梯度检查点通过torch.utils.checkpoint模块实现。例如:

from torch.utils.checkpoint import checkpointdef forward(self, x):# 普通forward:保存所有中间值x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 使用梯度检查点:只保存关键检查点
def forward(self, x):x = checkpoint(self.layer1, x)  # 只保存layer1的输出x = checkpoint(self.layer2, x)  # 只保存layer2的输出x = self.layer3(x)return x

PyTorch Lightning的gradient_checkpointing参数会自动为模型的所有层应用这种优化。

相关文章:

  • MySQL8:jdbc插入数据后获取自增ID
  • 实现Markdown文本转html并使用html2canvas导出图片
  • 可信计算的基石:TPM技术深度解析与应用实践
  • 图像融合中损失函数【1】--像素级别损失
  • 如何快速判断Excel文档是否被修改过?Excel多版本比对解决方案
  • 新能源知识库(65)逆变器和PCS的专用散热风扇介绍
  • Java学习第一周
  • Hum Brain Mapp.:从深度学习模型回归大脑:揭示区域预测因子及其与衰老的关系
  • QT6(46)5.2 QStringListModel 和 QListView :列表的模型与视图的界面搭建与源代码实现
  • Gartner《Generative AI Use - Case Comparison for Legal Departments》
  • python基于微信小程序的广西文化传承系统
  • 智慧水利新引擎,数字孪生流域解决方案
  • 生成式AI与智能体改写互联网、IT与工业经济格局
  • 深度学习:PyTorch卷积神经网络(CNN)之图像入门
  • 【Leetcode】有效的括号、用栈实现队列、用队列实现栈
  • 成都芯谷金融中心文化科技产业园:构建文化科技产业融合新标杆
  • MySQL 8.x配置MGR高可用+ProxySQL读写分离(二):ProxySQL配置MySQL代理及读写分离
  • 【GoLang】3、基于虚拟头尾节点快速实现双向链表
  • 计算Transformer的Flops
  • 从 0 到 1 打造社区产品:短说社区助力开启社交新篇
  • 网站开发哪家公司电话/中山seo关键词
  • 怎么做电视台网站/网络推广的优势
  • 用什么软件来建网站/网络营销pdf
  • 网站系统制作教程/网络营销网络推广
  • 福田网站设计哪家好/2020国内搜索引擎排行榜
  • 网络基础知识大全/佛山网络公司 乐云seo