当前位置: 首页 > news >正文

gradient_accumulation_steps的含义

在微调大模型时,会有一个gradient_accumulation_steps参数

training_args = SFTConfig(# gradient_checkpointing=True,  # 启用梯度检查点以降低显存# gradient_checkpointing_kwargs={'use_reentrant': False},per_device_train_batch_size=4,learning_rate=3e-5,gradient_accumulation_steps=2,bf16=True,save_strategy='no',num_train_epochs=10,log_level='debug',output_dir="output",max_length=4096,  # 在这里设置序列长度
)

底层代码给的解释如下:

gradient_accumulation_steps (`int`, *optional*, defaults to 1):Number of updates steps to accumulate the gradients for, before performing a backward/update pass.<Tip warning={true}>When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.</Tip> 翻译成中文

gradient_accumulation_steps 的作用是在硬件如GPU内存有限、无法容纳较大批量(batch_size)时,模拟更大批量的训练效果。
详细的解释:
1、核心概念:随机梯度下降(SGD)与批量(Batch Size)
在训练深度学习模型时,我们通常不会一次使用整个数据集来计算梯度(这计算量太大),也不会一次只用一个样本(这噪音太大)。而是折中地使用一个小批量(Mini-batch) 的数据。
per_device_train_batch_size=4: 这个参数定义了物理批量大小(Physical Batch Size)。意思是,GPU每次最多只能同时处理并前向传播(forward)和反向传播(backward)4个样本。
在得到这4个样本的梯度后,优化器(如Adam)就会立即用这个梯度来更新模型权重。

2、 问题:小批量可能不稳定
有时候,per_device_train_batch_size 被迫设置得很小(比如这里的4),因为模型很大或者序列很长(max_length=4096),导致显存(VRAM)不够用。但是,太小的批量会带来两个问题:
1)训练不稳定: 基于仅仅4个样本计算出的梯度方向可能噪音很大,不能很好地代表整个数据集的真实梯度方向,导致训练过程震荡,难以收敛。
2)性能下降: 在某些任务上,使用较大的批量训练出的模型最终性能会更好。

3、解决方案:梯度累积(Gradient Accumulation)
gradient_accumulation_steps 就是为了解决上述问题而设计的。它让在不增加显存占用的情况下,模拟一个更大批量的训练效果。
gradient_accumulation_steps=2: 这个参数的意思是“累积2步”。

4、它是如何工作的呢?
结合per_device_train_batch_size=4, gradient_accumulation_steps=2来看:
第一步:模型正常处理第一批4个样本,进行前向传播和反向传播,计算出梯度1。但是优化器不会立即更新权重,而是将梯度1累积(加总)到一个缓冲区;
第二步:模型处理下一批4个样本,再次进行前向传播和反向传播,计算出梯度2。同样,这个梯度2也会被累积到一个缓冲区;
第三步:更新权重。现在已经累积了2(步)*4(批量大小)=8个样本的梯度,此时,优化器才会使用这个累积后的平均梯度(总和除以步数,以保持梯度数值范围稳定)来一次性更新模型权重;
第四步:清空缓冲区。权重更新后,梯度累积缓冲区被清零,为下一个累积周期做准备。

从效果上看,模型行为就像是使用了一个大小为8(4*2)的批量在进行训练,但显存占用始终只相当于处理4个样本的量。

欢迎点赞关注,你的支持是我持续输出的动力!


文章转载自:

http://OstHQD6V.mcsdq.cn
http://y5I4unwz.mcsdq.cn
http://xxriV9U7.mcsdq.cn
http://jPw8nXJl.mcsdq.cn
http://xGIuh9cz.mcsdq.cn
http://2o1MDQla.mcsdq.cn
http://PGeSPilj.mcsdq.cn
http://dVMlnpTh.mcsdq.cn
http://dXWC9IbU.mcsdq.cn
http://nK7mDFWj.mcsdq.cn
http://GC8lKZcU.mcsdq.cn
http://RAxMDPSY.mcsdq.cn
http://Fn9zKz6O.mcsdq.cn
http://SbPbV9s1.mcsdq.cn
http://9nQmNzLb.mcsdq.cn
http://dqt3objF.mcsdq.cn
http://nsL4urVd.mcsdq.cn
http://si3drFVi.mcsdq.cn
http://eNxy66Qw.mcsdq.cn
http://ye1rpkjh.mcsdq.cn
http://m1RRdfpw.mcsdq.cn
http://jCROtd03.mcsdq.cn
http://gAONfSKS.mcsdq.cn
http://JjA79f4I.mcsdq.cn
http://UF4aPQvR.mcsdq.cn
http://KjJfqyFA.mcsdq.cn
http://YxNWOpm2.mcsdq.cn
http://5JCdMK4f.mcsdq.cn
http://t8QQDFQe.mcsdq.cn
http://CbMe5vPN.mcsdq.cn
http://www.dtcms.com/a/373744.html

相关文章:

  • 经典视觉跟踪算法的MATLAB实现
  • 编译器构造:从零手写汇编与反汇编程序(一)
  • 【Ubuntu20.04 + VS code 1.103.2 最新版,中文输入法失效】
  • 【开题答辩全过程】以 基于Python的北城公务用车系统设计与实现_为例,包含答辩的问题和答案
  • Proximal SFT:用PPO强化学习机制优化SFT,让大模型训练更稳定
  • 2025年Q3 GEO优化供应商技术能力评估与行业应用指南
  • 25上半年软考网工备考心得
  • XPath:从入门到能用
  • Kotlin协程 -> Job.join() 完整流程图与核心源码分析
  • [优选算法专题二滑动窗口——串联所有单词的子串]
  • VR森林防火模拟进行零风险演练,成本降低​
  • 玩转Docker | 使用Docker部署Kener状态页监控工具
  • Oracle 官网账号登不了?考过的证书还能下载吗?
  • Oracle 数据库高级查询语句方法
  • WSD3075DN56高性能MOS管在汽车电动助力转向系统(EPS)中的应用
  • 1.1 汽车运行滚动阻力
  • LinuxC++项目开发日志——高并发内存池(3-thread cache框架开发)
  • Android 自定义 TagView
  • 下沉一线强赋能!晓商圈多维帮扶护航城市共建者
  • YOLO12 改进、魔改|通道自注意力卷积块CSA-ConvBlock,通过动态建模特征图通道间的依赖关系,优化通道权重分配,在强化有效特征、抑制冗余信息
  • 提升数据库性能的秘密武器:深入解析慢查询、连接池与Druid监控
  • 中间件的日志分析
  • 机器宠物外壳设计的详细流程
  • OpenCV C++ 二值图像分析:从连通组件到轮廓匹配
  • Java分页 Element—UI
  • Flow-GRPO: Training Flow Matching Models via Online RL
  • C#中解析XML时遇到注释节点报错
  • 联邦学习辅导流程
  • MySQL MVCC原理
  • QSS加载失败的奇葩问题--已解决