当前位置: 首页 > news >正文

深度学习实战(基于pytroch)系列(八)softmax回归基于pytorch的代码实现

pytorch的代码实现

    • 导包
    • 数据加载
    • 定义模型
      • nn.Flatten() 工作机制
    • 初始化权重
    • 损失函数和优化器
    • 训练模型

上一节我们已经实现softmax回归从零开始使用python代码实现,这节我们用pytorch框架来实现,由于pytorch已经实现了大部分功能,所以写起来代码会十分简洁。

导包

这一步和上一节softmax回归从零开始使用python代码实现基本一致

import torch
from torch import nn
from torch import optim
import torchvision
import torchvision.transforms as transforms

数据加载

这一步和上一节softmax回归从零开始使用python代码实现基本一致

batch_size = 256
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(root='../data', train=False, download=True, transform=transform)train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

定义模型

softmax回归的输出层是一个全连接层。因此,我们添加一个输出个数为10的全连接层。我们使用均值为0、标准差为0.01的正态分布随机初始化模型的权重参数。使用Sequential来构建

net = nn.Sequential(nn.Flatten(),nn.Linear(28*28, 10)
)

nn.Flatten() 工作机制

默认(start_dim=1)

import torch
import torch.nn as nn
test_cases = [torch.randn(4),           # 1D: (4,)torch.randn(4, 3),        # 2D: (4, 3)  torch.randn(4, 3, 5),     # 3D: (4, 3, 5)torch.randn(4, 3, 5, 2),  # 4D: (4, 3, 5, 2)
]flatten = nn.Flatten()print("=== nn.Flatten() 默认行为 (start_dim=1) ===")
for i, tensor in enumerate(test_cases):input_shape = tensor.shapeoutput_shape = flatten(tensor).shapedim_change = len(input_shape) - len(output_shape)print(f"输入: {input_shape} → 输出: {output_shape} | 维度变化: {dim_change}")

默认情况下,相当于batch_size不参与展平,可以通过下面输出结果,了解其工作机制:

输入: torch.Size([4]) → 输出: torch.Size([4]) | 维度变化: 0
输入: torch.Size([4,3]) → 输出: torch.Size([4, 3]) | 维度变化: 0
输入: torch.Size([4, 3, 5]) →输出: torch.Size([4, 15]) | 维度变化: -1
输入: torch.Size([4, 3, 5, 2]) → 输出:torch.Size([4, 30]) | 维度变化: -2

初始化权重

使用均值为0、标准差为0.01的正态分布初始化权重

nn.init.normal_(net[1].weight, mean=0, std=0.01)
nn.init.constant_(net[1].bias, val=0)

损失函数和优化器

定义损失函数(PyTorch的CrossEntropyLoss已经包含了softmax)

loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

训练模型

num_epochs = 5for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:# 前向传播y_hat = net(X)l = loss(y_hat, y)# 反向传播optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]# 测试准确率test_acc = 0.0test_n = 0with torch.no_grad():for X, y in test_iter:test_acc += (net(X).argmax(dim=1) == y).sum().item()test_n += y.shape[0]test_acc /= test_nprint(f'epoch {epoch + 1}, loss {train_l_sum / len(train_iter):.4f}, 'f'train acc {train_acc_sum / n:.3f}, test acc {test_acc:.3f}')
http://www.dtcms.com/a/593194.html

相关文章:

  • Redis进阶
  • 做采购 通常在什么网站看邢台市人事考试网
  • 构筑码头数字防线:视频汇聚平台EasyCVR全方位码头海岸线监管方案
  • 计算机理论基础学习Day19
  • 金仓数据库运维优化实践:从成本中心到效能引擎的转型之路
  • 招标网站哪个好用seo指什么
  • Java面试题1:Java 中 Exception 和 Error 有什么区别?
  • MacX DVD Ripper Pro for Mac v6.8.2 安装教程|MacDVD转换软件怎么安装?
  • 【rkyv】 Rust rkyv 库全面指南
  • 【Rust 探索之旅】Rust 性能优化实战指南:从编译器到并发的完整优化方案(附京东/华为云真实案例)
  • 做网站除了域名还要买什么网站搭建dns有用吗
  • 分布式虚拟 Actor 技术在码头生产调度中的应用研究
  • AI Agent设计模式 Day 6:Chain-of-Thought模式:思维链推理详解
  • Anthropic 经济指数(Economic Index)概览
  • 深圳设计网站开发网站运行模式
  • iOS崩溃日志深度分析与工具组合实战,从符号化到自动化诊断的完整体系
  • C++ Qt的QLineEdit控件详解
  • 杭州专业网站设计制作中山企业网站推广公司
  • 软考 系统架构设计师系列知识点之杂项集萃(196)
  • 基于华为昇腾CANN的自定义算子开发
  • Java iText7 PDF模板填充工具:支持多页生成、中文无坑、多占位符精确定位
  • 2025年12月英语四级大纲词汇表PDF电子版(含正序版、乱序版和默写版)
  • 蝶山网站建设樟木头仿做网站
  • 【Linux网络编程】套接字编程
  • 网站怎么做弹出表单网站竞价 英文
  • 电子电气架构 --- 当前技术水平
  • OS 特性之PendSV 异常
  • 跆拳道东莞网站建设触屏版网站开发
  • 在电脑端企业微信打开内置浏览器并调试
  • Seata原理与简单示例