当前位置: 首页 > news >正文

深度学习优化-Gradient Checkpointing

数学原理参考:

梯度检查点技术(Gradient Checkpointing)详细介绍:中英双语-CSDN博客

视频讲解参考:

用梯度检查点来节省显存 gradient checkpointing_哔哩哔哩_bilibili

Gradient Checkpointing(梯度检查点

Gradient Checkpointing 是一种用于优化深度学习模型训练的技术,旨在减少训练过程中显存的占用。在深度神经网络训练中,通常需要存储每一层的激活值以用于反向传播计算梯度。然而,对于层数较多或参数量较大的模型,这些激活值会占用大量显存。

Gradient Checkpointing 的核心思想是在前向传播时选择性地保存部分激活值(称为检查点),而丢弃其他激活值。在反向传播时,如果需要这些被丢弃的激活值,则重新计算它们。通过这种方式,显存使用量可以从 O(L) 降低到 O(K),其中 L 是网络层数,K 是选择的检查点层数。

工作原理

  1. 选择检查点:在前向传播时,选择某些层作为检查点,保存这些层的激活值。

  2. 丢弃激活值:对于未被选为检查点的层,丢弃其激活值。

  3. 反向传播时重新计算:在反向传播时,如果需要被丢弃的激活值,则通过重新计算它们来获取,从而计算梯度。

a1和a3被丢弃,反向传播时,如果需要被丢弃的激活值,则需要重新计算

a1 = x * w1,

a3 = a2 * w3

优点与缺点

优点

  • 显著减少显存占用,使训练更大规模的模型成为可能。

  • 在显存受限的环境中,可以提高训练效率。

  • 允许使用更大的批量大小,从而加速训练。

缺点

  • 增加了计算开销,因为需要在反向传播时重新计算激活值。

  • 实现复杂度增加,需要修改代码来管理检查点。

  • 可能导致训练时间延长。

实现方法

在 PyTorch 中,可以通过 torch.utils.checkpoint 模块实现 Gradient Checkpointing。例如:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(256, 256)
        self.layer2 = nn.Linear(256, 256)
        self.layer3 = nn.Linear(256, 10)

    def forward(self, x):
        x = checkpoint.checkpoint(self.layer1, x)  # 应用梯度检查点
        x = checkpoint.checkpoint(self.layer2, x)
        x = self.layer3(x)  # 最后一层不需要检查点
        return x

在 DeepSpeed 中,可以通过配置文件启用 Gradient Checkpointing:

{
    "train_batch_size": 16,
    "gradient_accumulation_steps": 4,
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": true
    },
    "gradient_checkpointing": true
}

应用场景

Gradient Checkpointing 广泛应用于以下场景:

  • 训练大规模深度学习模型,如 7B 或 10B 参数的模型。

  • 在 GPU 显存有限的环境中优化训练。

  • 提高训练效率,同时减少硬件成本。

通过合理使用 Gradient Checkpointing,可以在有限的硬件资源下训练更大规模的模型,同时平衡显存和计算开销。

相关文章:

  • 华为欧拉系统安装redis官网最新版
  • 【视频】ffmpeg、Nginx搭建RTMP、HLS服务器
  • 文件解析漏洞靶场集锦详解
  • 段错误解析
  • Java 实现 Android ViewPager2 顶部导航:动态配置与高效加载指南
  • 深度剖析 Doris 数据倾斜,优化方案一网打尽
  • Docker Compose 之详解(Detailed Explanation of Docker Compose)
  • spring中将yaml文件转换为Properties
  • Nginx 多协议代理功能(Nginx Multi Protocol Proxy Function)
  • PyQt基础——简单的窗口化界面搭建以及槽函数跳转
  • 在 LaTeX 中强制表格位于页面顶部
  • CCF-CSP第34次认证第四题——货物调度【DP+剪枝】
  • 红黑树介绍
  • Matplotlib高阶技术全景解析
  • 《阿里云Data+AI:开启数据智能新时代》电子书上线啦!
  • 操作系统学不会?————一文速通(FCFS,SJF/SPF,SRTN,HRRN算法)保姆级解析
  • 4.数据存储**
  • Attention又升级!Moonshot | 提出MoE注意力架构:MoBA,提升LLM长文本推理效率
  • Python爬虫实战:基于 Scrapy 框架的腾讯视频数据采集研究
  • 蓝桥-数字接龙
  • 旅游响应式网站建设/百度文库首页
  • 成都个人网站制作公司/网络游戏推广平台
  • 创意网站案例/发布信息的免费平台
  • 电脑搭建网站需要空间/微信公众号的推广
  • 如何推广自己网站链接/萧山seo
  • 幼儿园管理网站模板下载/万网官网