当前位置: 首页 > news >正文

PyTorch神经网络工具箱(优化器)

优化器

PyTorch常用的优化方法都封装在torch.optim里面,其设计很灵活,可以扩展为自定
义的优化方法。所有的优化方法都是继承了基类optim.Optimizer,并实现了自己的优化步
骤。最常用的优化算法就是梯度下降法及其各种变种,后续章节我们将介绍各种算法的原
理,这类优化算法通过使用参数的梯度值更新参数。

3.2节使用的随机梯度下降法(SGD)就是最普通的优化器,一般SGD并说没有加速
效果,3.2节使用的SGD包含动量参数Momentum,它是SGD的改良版。

我们结合3.2小结内容,说明使用优化器的一般步骤为:

(1)建立优化器实例

导入optim模块,实例化SGD优化器,这里使用动量参数momentum(该值一般在
(0,1)之间),是SGD的改进版,效果一般比不使用动量规则的要好。

import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

以下步骤在训练模型的for循环中。

(2)向前传播

把输入数据传入神经网络Net实例化对象model中,自动执行forward函数,得到out输
出值,然后用out与标记label计算损失值loss。

out = model(img)
loss = criterion(out, label)

(3)清空梯度

缺省情况梯度是累加的,在梯度反向传播前,先需把梯度清零。

optimizer.zero_grad()

(4)反向传播

基于损失值,把梯度进行反向传播。

loss.backward()

(5)更新参数

基于当前梯度(存储在参数的.grad属性中)更新参数。

optimizer.step()

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# ===== 1. 修复未定义变量 =====
# 定义模型 (示例: 简单CNN)
class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.fc = nn.Linear(16 * 30 * 30, 10)  # 假设输入32x32图像def forward(self, x):x = torch.relu(self.conv1(x))x = x.view(x.size(0), -1)return self.fc(x)model = SimpleCNN()  # [!修复model!]# 定义超参数
EPOCHS = 100  # [!修复EPOCHS!]
LR = 0.01
MOMENTUM = 0.9  # 注意: 使用数值而非物理量# 创建示例数据集 (实际应替换为真实数据)
dummy_imgs = torch.randn(1000, 3, 32, 32)  # [!修复img数据源!]
dummy_labels = torch.randint(0, 10, (1000,))  # [!修复label数据源!]
dataset = TensorDataset(dummy_imgs, dummy_labels)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)  # [!修复dataloader!]criterion = nn.CrossEntropyLoss()  # [!修复criterion!]
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)# ===== 2. 增强型训练循环 =====
for epoch in range(EPOCHS):for batch in dataloader:img, label = batch  # [!解包获取img/label!]# 前向传播out = model(img)loss = criterion(out, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch {epoch + 1}/{EPOCHS} | Loss: {loss.item():.4f}') 
http://www.dtcms.com/a/337719.html

相关文章:

  • buuctf:护网杯_2018_gettingstart、oneshot_tjctf_2016
  • llamafactory使用qlora训练
  • VectorDB+FastGPT一站式构建:智能知识库与企业级对话系统实战
  • 使用LLaMA-Factory对大模型进行微调-详解
  • OSG+Qt —— 笔记2- Qt窗口绘制棋盘及模型周期运动(附源码)
  • linux:告别SSH断线烦恼,Screen命令核心使用指南
  • 第四章:大模型(LLM)】07.Prompt工程-(1)Prompt 原理与基本结构
  • 大数据分析-读取文本文件内容进行词云图展示
  • Zephyr 中的 bt_le_per_adv_set_data 函数的介绍和应用方法
  • [机器学习]09-基于四种近邻算法的鸢尾花数据集分类
  • 具身智能赋能轮椅机器人的认知革命与人机共生新范式
  • 【软考架构】第4章 信息安全的抗攻击技术
  • 从「行走」到「思考」:机器人进化之路与感知—决策链路的工程化实践
  • 微电网管控系统中python多线程缓存与SQLite多数据库文件连接池实践总结(含源码)
  • 安川YASKAWA焊接机器人保护气智能节气阀
  • 蓝牙 GFSK RX Core 架构解析
  • Linux下的软件编程——IPC机制
  • 重复(Repeat)和迭代(Iteration)区别、递归(Recursion)
  • 超级云平台:重构数字生态的“超级连接器“
  • 想找出版社出书?这样选就对了!
  • 哈工深无人机目标导航新基准!UAV-ON:开放世界空中智能体目标导向导航基准测试
  • 【论文阅读】-《GeoDA: a geometric framework for black-box adversarial attacks》
  • 基于Flink CDC实现联系人与标签数据实时同步至ES的实践
  • 后台管理系统-6-vue3之mockjs模拟和axios请求数据
  • python UV虚拟环境项目搭建
  • 和芯星通携手思博伦通信,测试验证系列导航定位芯片/模块符合GB/T 45086.1标准
  • 学习stm32 感应开关盖垃圾桶
  • 用 Python 实现一个“小型 ReAct 智能体”:思维链 + 工具调用 + 环境交互
  • 软件测试覆盖率:真相与实践
  • unity实现背包拖拽排序