当前位置: 首页 > news >正文

PyTorch——优化器(9)

优化器根据梯度调整参数,以达到降低误差

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)# 定义卷积神经网络模型
class TY(nn.Module):def __init__(self):super(TY, self).__init__()# 构建网络结构:3个卷积层+池化层组合,2个全连接层self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5MaxPool2d(2),                   # 最大池化,步长2Conv2d(32, 32, 5, padding=2),   # 第二层卷积MaxPool2d(2),                   # 第二次池化Conv2d(32, 64, 5, padding=2),   # 第三层卷积MaxPool2d(2),                   # 第三次池化Flatten(),                      # 将多维张量展平为向量Linear(1024, 64),               # 全连接层,输入1024维,输出64维Linear(64, 10),                 # 输出层,10个类别对应10个输出)def forward(self, x):# 定义前向传播路径x = self.model1(x)return x# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)# 训练20个完整轮次
for epoch in range(20):running_loss = 0.0  # 初始化本轮累计损失# 遍历数据加载器中的每个批次for data in dataloader:imgs, targets = data  # 获取图像和标签outputs = ty(imgs)    # 前向传播result_loss = loss(outputs, targets)  # 计算损失optim.zero_grad()     # 梯度清零,防止累积result_loss.backward()  # 反向传播计算梯度optim.step()          # 更新模型参数running_loss += result_loss  # 累加损失值# 打印本轮训练的累计损失print(f"Epoch {epoch+1}, Loss: {running_loss}")

相关文章:

  • Java设计模式深度解析:策略模式的核心原理与实战应用
  • 完成一个可交互的k8s管理平台的页面开发
  • MYSQL(二) ---MySQL 8.4 新特性与变量变更
  • Flutter快速上手,入门教程
  • 从OCR到Document Parsing,AI时代的非结构化数据处理发生了什么改变?
  • Docker 部署 Python 的 Flask项目
  • Apache POI操作Excel详解
  • 从上下文学习和微调看语言模型的泛化:一项对照研究 -附录
  • 一台电脑联网如何共享另一台电脑?网线方式
  • 一键更新依赖全指南:Flutter、Node.js、Kotlin、Java、Go、Python 等主流语言全覆盖
  • NLP中的input_ids是什么?
  • VSCode 工作区配置文件通用模板(CMake + Ninja + MinGW/GCC 编译器 的 C++ 或 Qt 项目)
  • 在compose中的Canvas用kotlin显示多数据波形闪烁的问题
  • 国产化Word处理控件Spire.Doc教程:Java实现HTML 转Word自动化
  • c#开发AI模型对话
  • Axios 取消请求的演进:CancelToken vs. AbortController
  • AWS中国区IAM相关凭证自行管理策略(只读CodeCommit版)
  • bug:undefined is not iterable (cannot read property Symbol(Symbol.iterator))
  • AI炼丹日志-28 - Audiblez 将你的电子书epub转换为音频mp3 做有声书
  • CATIA-CAD 拆图
  • 企业培训师资格证报考2023/seo优化网站推广
  • 建筑工程 技术支持 东莞网站建设/自媒体是如何赚钱的
  • 开网络公司做网站挣钱吗/网络营销战略
  • 自己做网站写文章/北京网络优化推广公司
  • 上海室内设计有限公司/潜江seo
  • 小米网站建设/余姚网站如何进行优化