当前位置: 首页 > news >正文

PyTorch——优化器(9)

优化器根据梯度调整参数,以达到降低误差

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)# 定义卷积神经网络模型
class TY(nn.Module):def __init__(self):super(TY, self).__init__()# 构建网络结构:3个卷积层+池化层组合,2个全连接层self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5MaxPool2d(2),                   # 最大池化,步长2Conv2d(32, 32, 5, padding=2),   # 第二层卷积MaxPool2d(2),                   # 第二次池化Conv2d(32, 64, 5, padding=2),   # 第三层卷积MaxPool2d(2),                   # 第三次池化Flatten(),                      # 将多维张量展平为向量Linear(1024, 64),               # 全连接层,输入1024维,输出64维Linear(64, 10),                 # 输出层,10个类别对应10个输出)def forward(self, x):# 定义前向传播路径x = self.model1(x)return x# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)# 训练20个完整轮次
for epoch in range(20):running_loss = 0.0  # 初始化本轮累计损失# 遍历数据加载器中的每个批次for data in dataloader:imgs, targets = data  # 获取图像和标签outputs = ty(imgs)    # 前向传播result_loss = loss(outputs, targets)  # 计算损失optim.zero_grad()     # 梯度清零,防止累积result_loss.backward()  # 反向传播计算梯度optim.step()          # 更新模型参数running_loss += result_loss  # 累加损失值# 打印本轮训练的累计损失print(f"Epoch {epoch+1}, Loss: {running_loss}")

http://www.dtcms.com/a/230263.html

相关文章:

  • Java设计模式深度解析:策略模式的核心原理与实战应用
  • 完成一个可交互的k8s管理平台的页面开发
  • MYSQL(二) ---MySQL 8.4 新特性与变量变更
  • Flutter快速上手,入门教程
  • 从OCR到Document Parsing,AI时代的非结构化数据处理发生了什么改变?
  • Docker 部署 Python 的 Flask项目
  • Apache POI操作Excel详解
  • 从上下文学习和微调看语言模型的泛化:一项对照研究 -附录
  • 一台电脑联网如何共享另一台电脑?网线方式
  • 一键更新依赖全指南:Flutter、Node.js、Kotlin、Java、Go、Python 等主流语言全覆盖
  • NLP中的input_ids是什么?
  • VSCode 工作区配置文件通用模板(CMake + Ninja + MinGW/GCC 编译器 的 C++ 或 Qt 项目)
  • 在compose中的Canvas用kotlin显示多数据波形闪烁的问题
  • 国产化Word处理控件Spire.Doc教程:Java实现HTML 转Word自动化
  • c#开发AI模型对话
  • Axios 取消请求的演进:CancelToken vs. AbortController
  • AWS中国区IAM相关凭证自行管理策略(只读CodeCommit版)
  • bug:undefined is not iterable (cannot read property Symbol(Symbol.iterator))
  • AI炼丹日志-28 - Audiblez 将你的电子书epub转换为音频mp3 做有声书
  • CATIA-CAD 拆图
  • 【从零学习JVM|第二篇】字节码文件
  • Kubernetes 网络方案:Flannel 插件全解析
  • MyBatis-Plus LambdaQuery 高级用法:JSON 路径查询与条件拼接的全场景解析
  • 判断:有那种使用了局部变量的递归过程在转换成非递归过程时才必须使用栈
  • 【从前端到后端导入excel文件实现批量导入-笔记模仿芋道源码的《系统管理-用户管理-导入-批量导入》】
  • 信号与系统汇总
  • OpenCV计算机视觉实战(10)——形态学操作详解
  • 【WPF】WPF 项目实战:构建一个可增删、排序的光源类型管理界面(含源码)
  • 2025 5 月 学习笔记
  • 705SJBH超市库存管理系统文献综述