当前位置: 首页 > news >正文

深度学习优化器进阶:从SGD到AdamW,不同优化器的适用场景

点击AladdinEdu,同学们用得起的【H卡】算力平台”,注册即送-H卡级别算力沉浸式云原生的集成开发环境80G大显存多卡并行按量弹性计费教育用户更享超低价


引言:优化器——深度学习的“导航系统”

在深度学习的浩瀚星海中,我们常常将神经网络模型比作一艘功能强大的宇宙飞船,海量的训练数据是其无尽的燃料,而优化器(Optimizer),则是这艘飞船的导航与推进系统。它的职责无比关键:如何根据当前的位置(模型状态)和观测到的星图(损失函数的梯度),智能地调整飞行的方向和速度(更新模型参数),最终高效、准确地带领我们抵达目的地(损失函数的最小值点)。

一个拙劣的导航系统可能会让飞船在原地打转(收敛缓慢)、陷入引力漩涡(陷入局部最优点)甚至直接失控(训练发散)。而一个优秀的导航系统,则能利用一切可用信息,实现平稳、快速且精准的航行。

从最古老的随机梯度下降(SGD)到如今大模型训练标配的AdamW,优化器的发展史就是一段人类尝试为深度学习飞船打造更智能“导航系统”的演进史。本文将带你深入这段历史,剖析动量、自适应学习率等核心原理,并最终解释为何AdamW能成为当今大模型训练的王者。


第一章:基石与起点——随机梯度下降(SGD)

1.1 核心思想:最速下降法

随机梯度下降(Stochastic Gradient Descent, SGD)是最朴素、最直观的优化方法。其核心思想源于最速下降法:函数的梯度方向指明了其值增长最快的方向,那么负梯度方向就是函数值下降最快的方向。

参数更新公式:
θ = θ - η * ∇J(θ)
其中:

  • θ:模型参数。
  • η:学习率(Learning Rate),决定了每次更新的步长,是超参数。
  • ∇J(θ):损失函数 J 关于参数 θ 的梯度。

在深度学习中,我们很少使用整个训练集(Batch)来计算精确梯度,因为计算成本太高。取而代之的是使用一个迷你批次(Mini-batch)的数据来估计梯度,这就是“随机”一词的由来,它引入了噪声,但大大提高了计算效率。

1.2 SGD的明显缺陷

尽管SGD简单有效,但它的问题也显而易见:

  1. 容易陷入局部最优点和鞍点:尤其是在高维空间中,鞍点(saddle point)的数量远多于局部最优点。SGD的更新方向完全依赖于当前点的梯度,如果在鞍点(梯度为0但非最优点),SGD就无法逃脱。
  2. 学习率选择困难:学习率 η 是一个全局的、固定的值。
    • 如果设置得太小,收敛速度会非常慢,训练时间漫长。
    • 如果设置得太大,可能会在最优点附近震荡,甚至无法收敛(发散)。
  3. 梯度方向的不稳定性:由于使用mini-batch估计梯度,梯度本身带有噪声,更新方向会频繁变化,导致优化路径曲折蜿蜒,收敛过程波动大。

为了解决这些问题,研究者们提出了两大核心改进思路:动量(Momentum)自适应学习率(Adaptive Learning Rate)


第二章:惯性之力——动量法(Momentum)

2.1 物理世界的启发:小球下山

想象一下,一个小球从山顶滚下。它不会仅仅根据当前地面的斜率立即改变方向,而是会由于惯性(动量)沿着之前的方向继续滚动,从而可以更平滑地越过小坑洼(局部最优点)和狭窄的通道(鞍点)。

动量法正是模拟了这一物理过程。它不仅考虑当前的梯度,还会积累之前的更新方向,形成一个“速度”项。

2.2 带动量的SGD(SGDM)

参数更新公式:

v = γ * v - η * ∇J(θ)  // 更新速度(引入动量)
θ = θ + v               // 用速度更新参数

其中:

  • v:速度向量,初始为0。它累积了过去的梯度信息。
  • γ:动量系数(通常设为0.9),决定了过去速度对当前更新的影响程度。

优势:

  1. 加速收敛:在梯度方向一致的维度上,速度向量会不断累积,更新幅度越来越大,从而加速收敛。
  2. 抑制震荡:在梯度方向频繁改变的维度上,由于动量的存在,更新方向不会发生剧烈变化,使得优化路径更加平滑,更容易逃离鞍点和局部最优点。
  3. 缓解梯度噪声:动量对mini-batch的梯度噪声起到了平滑作用,使得更新方向更接近真实的全数据集梯度方向。

动量法显著提升了SGD的性能,使其成为很长一段时间内计算机视觉等领域的默认优化器。但它仍然没有解决学习率需要手工精细调整的问题。


第三章:因材施教——自适应学习率算法

动量法让所有参数共享同一个学习率,并拥有相同的更新速度。这显然不是最优的。想象一下,有些参数(如频繁出现的特征的权重)我们已经学到了很多,希望它慢点更新;而有些参数(如罕见特征的权重)我们知之甚少,希望它快点更新。

自适应学习率算法应运而生,其核心思想是:为每个参数分配不同的、自适应调整的学习率

3.1 AdaGrad(Adaptive Gradient)

AdaGrad的思路是:为那些不频繁更新的参数(梯度出现次数少)赋予更大的学习率,为那些频繁更新的参数(梯度出现次数多)赋予更小的学习率。

实现方式:它会对每个参数的历史梯度平方进行累加。

参数更新公式:

cache = cache + (∇J(θ))²
θ = θ - (η / (√cache + ɛ)) * ∇J(θ)

其中:

  • cache:与参数 θ 同形的向量,累积了该参数所有历史梯度的平方和。
  • ɛ:一个极小值(如1e-7),防止分母为零。

优势与缺陷

  • 优势:在稀疏数据场景下(如自然语言处理)效果很好,因为它为不常见的特征赋予了更大的更新步长。
  • 缺陷cache 会随着训练进行单调递增,导致学习率会持续衰减,最终变得无限小,以至于在训练后期模型可能完全停止更新。
3.2 RMSProp(Root Mean Square Propagation)

为了克服AdaGrad学习率过早衰减的问题,RMSProp引入了一个衰减系数 ρ,将对历史梯度平方的累加改为指数移动平均(EMA)。这意味着它更关注最近一段时间的梯度 history,而不是整个历史。

参数更新公式:

cache = ρ * cache + (1 - ρ) * (∇J(θ))²  // 指数移动平均
θ = θ - (η / (√cache + ɛ)) * ∇J(θ)

其中:

  • ρ:衰减速率(通常设为0.9),控制着历史信息的重要性。

优势
RMSProp解决了AdaGrad学习率衰减过快的问题,成为了处理非平稳目标(如深度学习损失函数)的一个非常有效的算法,至今仍在许多场合使用。

3.3 Adam(Adaptive Moment Estimation)

2014年提出的Adam算法,可以看作是动量法(Momentum)RMSProp 的完美结合。它同时考虑了梯度的一阶矩(均值,提供动量)和二阶矩(未中心化的方差,提供自适应学习率),并进行了偏差校正。

参数更新公式:

// 更新一阶矩(动量)和二阶矩(方差)的估计
m = β1 * m + (1 - β1) * ∇J(θ)     // 一阶矩,类似动量
v = β2 * v + (1 - β2) * (∇J(θ))²   // 二阶矩,类似RMSProp的cache// 偏差校正:由于m和v初始为0,在训练初期会偏向于0
m_hat = m / (1 - β1^t)            // t是迭代次数
v_hat = v / (1 - β2^t)// 更新参数
θ = θ - η * m_hat / (√v_hat + ɛ)

其中:

  • m:一阶矩向量(估计梯度的均值)。
  • v:二阶矩向量(估计梯度平方的均值)。
  • β1, β2:矩估计的指数衰减速率(通常分别设为0.9和0.999)。
  • t:训练迭代的步数。

为什么Adam如此强大?

  1. 结合双重优势:它既像动量法一样积累了梯度方向,使得优化路径平滑且能加速收敛;又像RMSProp一样为每个参数自适应调整学习率。
  2. 偏差校正:确保了在训练初期,估计值(m_hat, v_hat)不会偏向于初始值0,使得训练初期更加稳定。
  3. 超参数鲁棒性:默认的超参数(β1=0.9, β2=0.999, ε=1e-8)在绝大多数情况下都表现良好,几乎可以“开箱即用”。

Adam因其卓越的性能和稳定性,迅速成为深度学习领域最流行、最通用的优化器,适用于绝大多数任务。


第四章:王者之选——AdamW与权重衰减

就在大家以为Adam是终极答案时,研究者们发现了一个微妙但重要的问题:Adam与权重衰减(L2正则化)的结合并不完美

4.1 L2正则化 vs. 权重衰减

在标准的SGD中,L2正则化权重衰减(Weight Decay) 是等价的。

  • SGD with L2正则化:损失函数 J(θ) = 原始损失 + (λ/2) * ||θ||²,梯度为 ∇J(θ) = ∇原始损失 + λ * θ。
  • SGD with 权重衰减:参数更新为 θ = θ - η * ∇原始损失 - η * λ * θ

可以看到,在SGD中,这两种写法最终得到的结果是完全一样的:- η * λ * θ 就是权重衰减项。

然而,在自适应学习率算法(如Adam)中,它们不再等价!

  • Adam with L2正则化:L2项会被加到损失中,因此也会被计入梯度 ∇J(θ)。这个梯度会被Adam的自适应学习率机制处理:η / √v_hat。这意味着,对于不同的参数,权重衰减的实际效果会被自适应地缩放。这违背了权重衰减的初衷——对所有权重进行同等程度的衰减。
  • 这会导致正则化效果不一致,甚至可能使得权重较大的参数反而得不到足够的衰减,不利于模型泛化。
4.2 AdamW的解决方案:解耦权重衰减

AdamW(Adam with Weight decay)的提出解决了这一问题。它的核心思想是:将权重衰减与梯度更新完全解耦

  • 不再将L2项加入损失函数
  • 而是在计算完自适应梯度更新后,直接、显式地将一个权重衰减项加到更新后的参数上。

AdamW的更新公式(简化版):

// 自适应更新部分(与Adam相同,但损失函数不含L2)
θ = θ - η * m_hat / (√v_hat + ɛ)// 解耦的权重衰减
θ = θ - η * λ * θ

(注:实际实现中,权重衰减步骤通常与参数更新步骤合并在一起,但概念上是独立的。)

这样做的好处是:权重衰减 η * λ * θ 不再受到自适应学习率 η / √v_hat 的影响。它对所有参数都施加了完全一致、可控的衰减力度,真正实现了正则化的目的

4.3 为什么AdamW成为大模型训练的标配?

大规模模型(如GPT、BERT、ViT等)拥有巨大的容量和数以亿计的参数,极其容易过拟合。因此,正则化技术至关重要。

  1. 更有效的正则化:AdamW提供了比Adam+L2更纯粹、更直接、更有效的权重衰减方式,被证明能显著提升大模型的泛化能力,是防止过拟合的关键。
  2. 训练稳定性:解耦的设计使得超参数(尤其是学习率 η 和权重衰减系数 λ)的调节更加稳定和可预测。研究者发现了适用于AdamW的“解耦超参数调优策略”,例如,对于Transformer架构,学习率和权重衰减通常可以独立设置而不会相互干扰。
  3. 广泛的实证成功:在ImageNet、Transformer、GAN等一系列重要的模型和数据集上,AdamW都表现出了比Adam更优异的最终性能,尤其是在训练时间很长的大规模任务中。这一结果被无数论文和实践所证实。

因此,当你在阅读当今最前沿的大模型论文时(如GPT系列、LLaMA、Stable Diffusion等),几乎总能在“Training Details”一节中找到“We use AdamW as our optimizer”的字样。它已然成为了大模型训练中不可或缺的“标配”组件。


第五章:优化器选择指南与总结

没有放之四海而皆准的“最佳”优化器。理解其原理才能做出最适合的选择。

优化器核心思想优点缺点/适用场景
SGD负梯度方向下降简单,理论清晰易陷入鞍点,收敛慢,依赖精细调参
SGDM引入动量加速和平滑加速收敛,减轻震荡,易于逃离鞍点学习率仍需手动调整;CNN等传统CV任务
AdaGrad为稀疏参数增大学习率适合稀疏数据(如NLP)学习率过早衰减至零,后期无法学习
RMSProp指数平均解决衰减问题解决了AdaGrad的缺陷,适应非平稳目标RNN等循环网络效果很好
AdamMomentum + RMSProp + 偏差校正默认参数好,收敛快,适用性广与L2正则化结合不佳,泛化性可能稍差
AdamWAdam + 解耦权重衰减泛化能力极强,超参数稳定大模型训练、Transformer架构、LLM的绝对首选

总结与建议:

  • 新手入门/小规模模型:可以从Adam开始,它简单且效果不错。
  • 计算机视觉(CNN)带动量的SGD(SGDM) 经过精心调参后,仍然可能达到比Adam更优的最终精度,但需要更多的调参成本。
  • 自然语言处理/循环神经网络RMSPropAdam系列一直是传统上的好选择。
  • 大规模深度学习(Transformer, LLM, 预训练模型)AdamW是目前无可争议的最佳选择和工业标准。如果你要训练BERT或微调LLaMA,AdamW是你的不二之选。
  • 理论兴趣NAdam(Nesterov加速的Adam)、AMSGrad等变体也值得了解,它们试图解决Adam可能存在的收敛性问题,但在实践中的提升并不总是显著。

优化器的演进历程,体现了深度学习领域从直觉到理论、从粗糙到精细、从通用到专用的不断发展。理解从SGD到AdamW的每一步改进背后的“为什么”,远比记住几个公式更重要。这能帮助你在面对新的模型和任务时,做出更明智、更自信的选择,为你自己的“深度学习飞船”装上最合适的导航系统。


点击AladdinEdu,同学们用得起的【H卡】算力平台”,注册即送-H卡级别算力沉浸式云原生的集成开发环境80G大显存多卡并行按量弹性计费教育用户更享超低价

http://www.dtcms.com/a/394329.html

相关文章:

  • C++ 之 【C++的IO流】
  • truffle学习笔记
  • 现代循环神经网络
  • vlc播放NV12原始视频数据
  • ThinkPHP8学习篇(七):数据库(三)
  • 链家租房数据爬虫与可视化项目 Python Scrapy+Django+Vue 租房数据分析可视化 机器学习 预测算法 聚类算法✅
  • MQTT协议知识点总结
  • C++ 类和对象·其一
  • TypeScript里的类型声明文件
  • 【LeetCode - 每日1题】设计电影租借系统
  • Java进阶教程,全面剖析Java多线程编程,线程安全,笔记12
  • DCC-GARCH模型与代码实现
  • 实验3掌握 Java 如何使用修饰符,方法中参数的传递,类的继承性以及类的多态性
  • 【本地持久化】功能-总结
  • 深入浅出现代FPU浮点乘法器设计
  • LinkedHashMap 访问顺序模式
  • 破解K个最近点问题的深度思考与通用解法
  • 链式结构的特性
  • 报表1-创建sql函数get_children_all
  • 9月20日 周六 农历七月廿九 哪些属相需要谨慎与调整?
  • godot实现tileMap地图
  • 【Unity+VSCode】NuGet包导入
  • QEMU虚拟机设置网卡模式为桥接,用xshell远程连接
  • Week 17: 深度学习补遗:Boosting和量子逻辑门
  • 【论文速递】2025年第13周(Mar-23-29)(Robotics/Embodied AI/LLM)
  • Webpack进阶配置
  • 【LeetCode 每日一题】3227. 字符串元音游戏
  • 【图像算法 - 26】使用 YOLOv12 实现路面坑洞智能识别:构建更安全的智慧交通系统
  • 009 Rust函数
  • IT疑难杂症诊疗室