当前位置: 首页 > news >正文

【大模型LLM】梯度累积(Gradient Accumulation)原理详解

在这里插入图片描述

梯度累积(Gradient Accumulation)原理详解

梯度累积是一种在深度学习训练中常用的技术,特别适用于显存有限但希望使用较大批量大小(batch size)的情况。通过梯度累积,可以在不增加单个批次大小的情况下模拟较大的批量大小,从而提高模型的稳定性和收敛速度。

基本概念

在标准的随机梯度下降(SGD)及其变体(如Adam、RMSprop等)中,每次更新模型参数时都需要计算整个批次数据的损失函数梯度,并立即用这个梯度来更新模型参数。然而,在处理大规模数据集或使用非常大的模型时,单个批次的数据量可能会超出GPU显存的容量。此时,梯度累积技术就可以发挥作用。

工作原理

梯度累积的核心思想是:将多个小批次(mini-batch)的梯度累加起来,然后一次性执行一次参数更新。具体步骤如下:

  1. 初始化梯度累积器:在每个训练步骤开始时,初始化一个梯度累积器(通常为零)。
  2. 前向传播与梯度计算
    • 对于每一个小批次 i(从 1 到 k),执行前向传播计算损失。
    • 执行反向传播计算该小批次的梯度。
  3. 累积梯度:将当前小批次的梯度累加到梯度累积器中。
  4. 参数更新:当累积了 k 个小批次的梯度后,使用累积的梯度来更新模型参数,并重置梯度累积器。
详细步骤

假设我们希望使用的批量大小是 N,但由于显存限制只能使用较小的批量大小 n(其中 N = k * n),那么我们可以进行 k 次前向和后向传播,每次都计算一个小批次的梯度并将其累加,直到累积了 k 个小批次的梯度之后,再进行一次参数更新。

示例代码

以下是一个简单的PyTorch示例,展示了如何实现梯度累积:

import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 设置梯度累积步数
accumulation_steps = 4
optimizer.zero_grad()  # 清空梯度for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)# 将损失除以累积步数,使得总的损失不变loss = loss / accumulation_steps# 反向传播计算梯度loss.backward()if (i + 1) % accumulation_steps == 0:# 累积足够步数后,执行优化步骤optimizer.step()optimizer.zero_grad()  # 清空梯度
关键点解释
  1. 损失缩放:由于我们将一个大批次分成多个小批次,并且每次只计算一个小批次的损失,因此需要将每个小批次的损失除以累积步数 accumulation_steps,以确保总的损失值保持不变。

  2. 梯度累积:每次反向传播后,梯度会被累加而不是立即用于更新参数。只有当累积了足够的步数后,才会使用累积的梯度进行一次参数更新。

  3. 参数更新:在累积了足够的梯度后,调用 optimizer.step() 来更新模型参数,并清空梯度累积器(即调用 optimizer.zero_grad())。

优点
  • 突破显存限制:通过使用较小的批量大小,可以有效地减少每一步所需的显存量,从而允许在有限的硬件资源上训练更大的模型或使用更大的批量大小。
  • 模拟大批次训练效果:梯度累积实际上模拟了使用较大批量大小的效果,有助于提高模型训练的稳定性和收敛速度。
  • 灵活性:可以根据实际硬件条件灵活调整累积步数,适应不同的训练需求。
注意事项
  • 学习率调整:由于梯度累积实际上是将多个小批次的梯度累加起来进行一次更新,因此需要相应地调整学习率。例如,如果原始设置的学习率为 lr,并且使用了 k 步梯度累积,则新的有效学习率应为 lr * k
  • 随机性影响:梯度累积可能会引入一定的随机性,因为不同小批次之间的顺序可能会影响最终的梯度累积结果。不过,在实践中这种影响通常是可以接受的。
总结

梯度累积是一种非常实用的技术,特别是在显存受限但希望利用更大批量大小的情况下。它不仅帮助克服了硬件限制,还能够保持甚至提升模型训练的质量。通过合理配置梯度累积步数和学习率,可以显著改善训练效率和效果。

http://www.dtcms.com/a/302585.html

相关文章:

  • linux I2C设备AW2013驱动示例
  • rhel网卡配置文件、网络常用命令、网卡名称优化和模拟不同网络区域通信
  • 服务器中的防火墙设置需要打开吗
  • 服务器查日志太慢,试试grep组合拳
  • 利用frp实现内网穿透功能(服务器)Linux、(内网)Windows
  • CentOS7 安装和配置教程
  • RF随机森林分类预测+特征贡献SHAP分析,通过特征贡献分析增强模型透明度,Matlab代码实现,引入SHAP方法打破黑箱限制,提供全局及局部双重解释视角
  • 论文:M矩阵
  • 高可用集群Keepalived、Redis、NoSQL数据库Redis基础管理
  • 常用设计模式系列(十四)—模板方法模式
  • 在 CentOS 上安装 FFmpeg
  • 行业案例:杰和科技为智慧教育构建数字化硬件底座
  • UML类图--基于大话设计模式
  • 【设计模式】状态模式 (状态对象(Objects for States))
  • NBIOT模块 BC28通过MQTT协议连接到电信云
  • Google Chrome V8< 13.7.120 沙箱绕过漏洞
  • 设计模式(二十三)行为型:模板方法模式详解
  • 从 “看天吃饭” 到 “精准可控”:边缘计算网关如何引爆智慧农业种植变革?
  • 新手向:破解VMware迁移难题
  • 解放io_uring编程:liburing实战指南与经典cat示例解析
  • Unity_UI_NGUI_组合控件2
  • Rust实战:AI与机器学习自动炒饭机器学习
  • puppeteer 系列模块的系统性、详细讲解
  • Ubuntu系统完整配置教程
  • InfluxDB 与 HTTP 协议交互进阶(一)
  • 设计模式实战:自定义SpringIOC(理论分析)
  • 无界设计新生态:Penpot开源平台与cpolar的云端协同创新实践
  • 第二十二节 MATLAB转置向量、MATLAB追加向量
  • C++---初始化列表(initializer_list)
  • 基于黑马教程——微服务架构解析(二):雪崩防护+分布式事务