当前位置: 首页 > news >正文

全面详解 PyTorch 中的优化器

全面地详解 PyTorch 中的优化器。优化器是深度学习模型训练的核心组成部分,它决定了模型参数如何根据损失函数的梯度进行更新。

一、 优化器是什么?

在 PyTorch 中,优化器是一个封装了各种优化算法的类,它的核心职责是:
根据计算得到的梯度,更新模型的可学习参数(即 requires_grad=True 的参数),以最小化损失函数。

二、 优化器的基本使用范式

所有优化器的使用都遵循一个标准流程,这也是 PyTorch 设计优雅的地方:

import torch
import torch.nn as nn# 1. 定义一个模型
model = MyModel()# 2. 定义损失函数
criterion = nn.CrossEntropyLoss()# 3. 定义优化器,并将模型的参数传递给它
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 进入训练循环
for epoch in range(num_epochs):for data, labels in dataloader:# 4. 前向传播outputs = model(data)loss = criterion(outputs, labels)# 5. 清零梯度(非常重要!)optimizer.zero_grad()# 6. 反向传播,计算梯度loss.backward()# 7. 执行一步参数更新optimizer.step()# (可选)8. 学习率调度# scheduler.step()

关键步骤解释:

  • optimizer.zero_grad():在每次反向传播前,必须将优化器中所有参数的梯度重置为零。因为 PyTorch 的梯度是累加的,如果不清零,下一次.backward()时梯度会与当前梯度叠加。
  • loss.backward():执行反向传播,通过自动微分计算每个参数的梯度,并存储在 parameter.grad 中。
  • optimizer.step():根据优化算法的规则,利用 .grad 中的梯度值,执行一次参数更新。

三、 主流优化器详解

PyTorch 在 torch.optim 模块中提供了丰富的优化器。我们按类别来详解最常见的几种。

1. 随机梯度下降及其变种

a. 标准 SGD
最基础的优化算法。

optimizer = torch.optim.SGD(model.parameters(), lr=0.01,         # 学习率momentum=0,       # 动量,默认为0dampening=0, weight_decay=0,   # L2 正则化系数nesterov=False    # 是否使用 Nesterov 动量
)
  • 原理param = param - lr * param.grad
  • 特点: 简单,但容易在沟壑中震荡,收敛慢。

b. SGD with Momentum
引入“动量”概念,模拟物理中的惯性。

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  • 原理
    • v = momentum * v - lr * gv 是速度,g 是当前梯度)
    • param = param + v
  • 作用
    • 加速收敛:在相关梯度方向上积累速度,更新更快。
    • 减少震荡:有助于穿过狭窄的沟壑和局部最优点。

c. SGD with Nesterov Momentum
Nesterov 是 Momentum 的改进版,具有“前瞻性”。

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
  • 原理:先根据累积的速度“跳一步”,然后在这个“未来”位置计算梯度,再修正。
  • 公式v = momentum * v - lr * gradient(param + momentum * v)
  • 作用: 比标准 Momentum 在理论上收敛性更好,实践中也常用。
2. 自适应学习率优化器

这类优化器为每个参数自动调整学习率,是当前最主流的类别。

a. AdaGrad
为不频繁出现的参数赋予更大的学习率。

optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01, weight_decay=0)
  • 原理:累积历史梯度的平方和,学习率除以这个累积和的平方根。
  • 公式cache += g^2param = param - (lr / (sqrt(cache) + eps)) * g
  • 特点
    • 优点:适合处理稀疏数据。
    • 缺点:累积和会持续增长,导致学习率过早、过度衰减,最终无法学习。

b. RMSprop
为了解决 AdaGrad 学习率急剧下降的问题,引入了衰减因子。

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99,    # 平滑常数,相当于 Momentum 中的 betaeps=1e-8, weight_decay=0, momentum=0      # 可以额外加入动量
)
  • 原理:使用指数加权移动平均来累积梯度平方,而不是简单求和。
  • 公式cache = alpha * cache + (1 - alpha) * g^2param = param - (lr / (sqrt(cache) + eps)) * g
  • 特点: 是 AdaGrad 的改进,在实践中效果很好,尤其在 RNN 中。

c. Adam
目前最流行、最通用的优化器,结合了 Momentum 和 RMSprop 的思想。

optimizer = torch.optim.Adam(model.parameters(), lr=0.001,              # 通常使用较小的学习率betas=(0.9, 0.999),    # (一阶矩估计的衰减率, 二阶矩估计的衰减率)eps=1e-8, weight_decay=0, amsgrad=False
)
  • 原理
    1. 计算梯度的一阶矩(均值,有偏)和二阶矩(未中心化的方差,有偏)。
      • m = beta1 * m + (1 - beta1) * g
      • v = beta2 * v + (1 - beta2) * g^2
    2. 对一阶和二阶矩进行偏差校正,以解决初始零偏问题。
      • m_hat = m / (1 - beta1^t)t 是时间步)
      • v_hat = v / (1 - beta2^t)
    3. 更新参数: param = param - lr * m_hat / (sqrt(v_hat) + eps)
  • 特点
    • 优点:通常收敛快,对超参数不敏感(除了学习率),是很好的默认选择。
    • AMSGrad 变体amsgrad=True,使用 v 的历史最大值,可以解决某些收敛性问题,但并不总是有效。

d. AdamW
Adam 的一个改进,正确地实现了权重衰减

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
  • 与 Adam 的区别
    • Adam(L2正则化): 在梯度中直接加入 weight_decay * param,这实际上不是真正的 L2 正则化。
    • AdamW(解耦权重衰减): 将权重衰减与梯度更新分离开,直接在参数上应用衰减: param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
  • 特点: 在许多任务上(尤其是计算机视觉)表现优于标准 Adam,推荐优先尝试

四、 如何选择优化器?

这是一个经验性问题,但有一些通用准则:

  1. 新手或默认选择AdamWAdam。它们在大多数情况下都能工作得很好。
  2. 如果追求最佳性能: 可以尝试 SGD with MomentumNesterov。虽然它需要更精细的学习率调整和更长的训练时间,但最终收敛的泛化性能有时会优于自适应方法。
  3. 处理稀疏数据: 可以考虑 AdaGrad 或其变种。
  4. 训练 RNN/LSTMRMSpropAdam 都是不错的选择。

简单建议:从 AdamW 开始,如果训练稳定但性能达不到预期,再尝试调优过的 SGD。


五、 优化器的进阶用法

1. 为不同层设置不同的超参数

通过 parameter groups 实现,非常灵活。

optimizer = torch.optim.SGD([{'params': model.base.parameters()},                  # 默认参数组{'params': model.classifier.parameters(), 'lr': 1e-3} # 自定义参数组
], lr=1e-2, momentum=0.9)# 在训练中动态修改特定组的学习率
for param_group in optimizer.param_groups:if ‘classifier’ in param_group[‘name’]: # 需要自己添加‘name’字段param_group[‘lr’] = 0.001 * 0.1
2. 学习率调度器

优化器负责更新参数,而调度器负责在训练过程中调整优化器的超参数(主要是学习率)。

from torch.optim import lr_scheduleroptimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 定义调度器
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) # 每30个epoch,学习率乘0.1for epoch in range(num_epochs):# ... 一个epoch的训练 ...# 在epoch结束时更新学习率scheduler.step()

常用调度器:

  • StepLR: 固定步长衰减。
  • MultiStepLR: 在指定epoch衰减。
  • ExponentialLR: 指数衰减。
  • CosineAnnealingLR: 余弦退火,非常有效。
  • ReduceLROnPlateau动态调度,当指标(如验证损失)停止改善时降低学习率。
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
    # 在每个epoch后
    val_loss = ...
    scheduler.step(val_loss) # 传入监控的指标
    
3. 梯度裁剪

防止梯度爆炸,在 RNN 中尤其重要。在 loss.backward() 之后,optimizer.step() 之前调用。

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 按范数裁剪
# 或者
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5) # 按值裁剪

六、 优化器选择与调参经验总结

  1. 学习率是首要超参数: 如果损失不下降(LR太小)或变成NaN(LR太大),首先调整学习率。使用学习率查找器是一个好方法。
  2. Adam/AdamW 的默认参数通常很好betas=(0.9, 0.999)eps=1e-8 在绝大多数情况下不需要修改。
  3. 权重衰减很重要: 即使是很小的值(如 1e-4)也能显著影响泛化能力。
  4. 配合学习率调度器: 静态的学习率通常不是最优的,使用调度器(如余弦退火或 ReduceLROnPlateau)能带来巨大提升。
  5. 监控训练过程: 使用 TensorBoard 或 WandB 监控损失和梯度直方图,这能帮助你诊断优化问题。

通过深入理解这些优化器的工作原理和使用技巧,你就能更有信心地驾驭深度模型的训练过程,使其更快、更稳地收敛。

http://www.dtcms.com/a/458185.html

相关文章:

  • npm 扩展vite、element-plus
  • 好看的网站首页特效网页设计作品简单
  • dedecms织梦古典艺术书画书法公司企业网站源码模板网页设计茶叶网站建设
  • 网站文件名优化深圳龙华区地图
  • SystemVerilog的隐含随机约束
  • 类似站酷的网站建站网站在线考试答题系统怎么做
  • 网站备案号查询网互联网企业概念
  • [01] Qt的UI框架选择和对比
  • 吴恩达机器学习课程(PyTorch 适配)学习笔记:3.3 推荐系统全面解析
  • 劳动服务公司网站源码线上销售模式有哪些
  • 青岛北京网站建设公司哪家好个人在湖北建设厅网站申请强制注销
  • 微网站策划方案wordpress做app下载文件
  • 建设网站德州百度招聘官网首页
  • 基于GA-SVM的织物瑕疵种类识别算法matlab仿真,包含GUI界面
  • IT 疑难杂症诊疗室:破解数字世界的 “疑难杂症”
  • 做网站用笔记本做服务器吗驾校网站建设方案
  • 绍兴外贸网站建设嘉祥网站建设
  • 机器视觉Halcon3D中create_pose的作用
  • 个人博客建站wordpress网站建设岗位能力评估表
  • 建网站哪家好绿色建筑网站
  • 万网域名价格重庆百度搜索排名优化
  • CPP 内存管理
  • 专做网页的网站设计网站大全湖南岚鸿网站大全
  • 小公司网站怎么建一级水蜜桃
  • Java25 新特性介绍
  • 珠海做网站找哪家好在线网站推荐几个
  • 倍增:64位整除法
  • 钓鱼网站开发系列教程2013电子商务网站建设
  • Python协程详解:从并发编程基础到高性能服务器开发
  • 以太网数据包协议字段全解析(进阶补充篇)