当前位置: 首页 > news >正文

基于pytorch.nn模块实现softmax回归模型

课程:b站up 跟李沐学AI 的系列课程 《动手学深度学习》

笔记:本笔记基于以上课程所写,有疑问的地方,可评论区留言,或自行观看b站视频

完整代码

import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch.utils import data
from torchvision import transforms# ========= 公共函数 start =========
def get_fashion_mnist_lables(labels):"""将Fashion-MNIST数据集的数字标签转换为对应的文本标签参数:labels (list/int): 输入的标签索引返回:list: 对应的文本标签列表(如输入[0,1],返回['t-shirt', 'trouser'])"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=2.5):"""绘制图像网格的可视化函数参数:imgs (list): 待显示的图像列表(可以是Tensor或PIL.Image对象)num_rows (int): 图像网格的行数num_cols (int): 图像网格的列数titles (list): 每张图像的标题(可选,默认None)scale (float): 图像尺寸缩放因子(默认2.5)"""# 计算画布尺寸:列数*缩放因子作为宽度,行数*缩放因子作为高度figsize = (num_cols * scale, num_rows * scale)# 创建子图(展平为一维数组便于遍历)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()  # 将二维坐标轴数组展平为一维for i, (ax, img) in enumerate(zip(axes, imgs)):# 处理Tensor类型的图像(需要转换为numpy数组)if torch.is_tensor(img):ax.imshow(img.numpy())  # Tensor的shape应为[C, H, W],imshow自动处理else:# 处理PIL.Image类型(直接显示)ax.imshow(img)# 隐藏坐标轴ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)# 添加标题(如果提供)if titles:ax.set_title(titles[i])plt.show()  # 显示图像def load_data_fashion_mnist(batch_size, resize=None):"""加载Fashion-MNIST数据集并返回数据加载器参数:batch_size (int): 每批数据的样本数量resize (tuple/int): 可选,调整图像尺寸(如(224,224)或224),默认不调整返回:tuple: (训练数据加载器, 测试数据加载器)"""# 定义数据预处理流程trans = [transforms.ToTensor()]  # 第一步:转换为Tensor(自动归一化到[0,1])if resize:# 如果需要调整尺寸,插入Resize变换(注意顺序:先调整尺寸再转Tensor)# Resize的参数可以是单个整数(等比缩放到该边长)或元组(宽,高)trans.insert(0, transforms.Resize(resize))# 组合所有变换为一个复合变换trans = transforms.Compose(trans)# 加载训练集(download=True表示若本地不存在则下载)mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST',  # 数据存储路径train=True,  # 加载训练集download=True,  # 自动下载(首次运行时需要)transform=trans  # 应用预处理变换)# 加载测试集mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST',train=False,  # 加载测试集download=True,transform=trans)# 创建数据加载器(将数据集分批、打乱顺序)# 训练集需要打乱顺序(shuffle=True),测试集不需要(shuffle=False)return (data.DataLoader(mnist_train, batch_size, shuffle=True),data.DataLoader(mnist_test, batch_size, shuffle=False))def accuracy(y_hat, y):"""计算模型预测的准确率(正确分类的样本比例)参数:y_hat (Tensor): 模型输出的预测结果(形状通常为[批量大小, 类别数])y (Tensor): 真实标签(形状通常为[批量大小])返回:float: 准确率(范围0-1)"""# 如果y_hat是多分类的logits(形状[批量大小, 类别数]),取每行的最大值索引作为预测类别if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 按行取最大值的索引(axis=1表示按列找最大值)# 比较预测值和真实值是否相等(类型需要一致)cmp = y_hat.type(y.dtype) == y  # 例如:y是long类型,y_hat转换为long后比较# 统计相等的数量(转换为y.dtype类型求和,避免浮点运算问题)return float(cmp.type(y.dtype).sum())class Accumulator:"""用于累积多个统计量的工具类(例如损失总和、正确数、样本数)示例:acc = Accumulator(3)  # 初始化3个统计量acc.add(1.0, 5, 10)   # 累加:第一个统计量+1.0,第二个+5,第三个+10acc[0]                # 获取第一个统计量的值(1.0)"""def __init__(self, n):"""初始化n个统计量,初始值为0.0"""self.data = [0.0] * ndef add(self, *args):"""将传入的数值累加到对应的统计量中"""# args的长度应等于n,每个参数对应一个统计量self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):"""重置所有统计量为0.0"""self.data = [0.0] * len(self.data)def __getitem__(self, idx):"""通过索引获取统计量的当前值"""return self.data[idx]def evaluate_accuracy(data_iter, net):"""评估模型在指定数据集上的准确率参数:data_iter (DataLoader): 数据迭代器(测试集或验证集)net (nn.Module): 待评估的神经网络模型返回:float: 模型在该数据集上的准确率"""# 如果是PyTorch模块,切换到评估模式(关闭Dropout、BatchNorm的统计更新等)if isinstance(net, torch.nn.Module):net.eval()# 初始化累加器(累积正确数和总样本数)metric = Accumulator(2)  # 索引0: 正确预测数,索引1: 总样本数# 遍历数据集的每个批量for X, y in data_iter:# 模型前向传播得到预测结果y_hat = net(X)# 计算当前批量的准确率(正确预测数)correct = accuracy(y_hat, y)# 累加正确数和总样本数(y.numel()获取当前批量的样本总数)metric.add(correct, y.numel())# 准确率 = 正确预测数 / 总样本数return metric[0] / metric[1]def train_epoch_ch3(net, train_iter, loss, trainer):"""训练模型一个epoch(遍历整个训练集一次)参数:net (nn.Module): 待训练的神经网络train_iter (DataLoader): 训练数据迭代器loss (nn.Module): 损失函数(如CrossEntropyLoss)trainer (torch.optim.Optimizer): 优化器(如SGD)返回:tuple: (训练损失, 训练准确率)"""# 切换到训练模式(启用Dropout、BatchNorm的训练行为等)if isinstance(net, torch.nn.Module):net.train()# 初始化累加器(累积损失总和、正确预测数、总样本数)metric = Accumulator(3)  # 索引0: 损失总和,索引1: 正确数,索引2: 总样本数# 遍历训练集的每个批量for X, y in train_iter:# 前向传播:计算预测结果y_hat = net(X)# 计算当前批量的损失值l = loss(y_hat, y)# 反向传播并更新参数(根据优化器类型处理)if isinstance(trainer, torch.optim.Optimizer):# 标准优化器流程(如SGD、Adam)trainer.zero_grad()  # 清空梯度缓存l.backward()  # 反向传播计算梯度trainer.step()  # 优化器更新参数# 累加当前批量的损失(乘以批量大小,因为l是平均损失)、正确数、总样本数metric.add(float(l) * len(y),  # 总损失 = 平均损失 * 样本数accuracy(y_hat, y),  # 当前批量的正确数y.size().numel()  # 当前批量的样本总数(等价于len(y)))else:# 兼容自定义优化器(如手动实现SGD)l.sum().backward()  # 对损失总和反向传播(l可能是未平均的损失)trainer.step()  # 更新参数# 累加总损失(l.sum()是当前批量的总损失)、正确数、总样本数metric.add(float(l.sum()),  # 当前批量的总损失accuracy(y_hat, y),  # 当前批量的正确数y.numel()  # 当前批量的样本总数)# 计算平均训练损失和准确率# 损失:总损失 / 总样本数# 准确率:正确数 / 总样本数return metric[0] / metric[2], metric[1] / metric[2]def train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer):"""训练模型多个epoch,并在每个epoch后评估测试集准确率参数:net (nn.Module): 待训练的神经网络train_iter (DataLoader): 训练数据迭代器test_iter (DataLoader): 测试数据迭代器loss (nn.Module): 损失函数num_epochs (int): 训练轮次(遍历整个训练集的次数)trainer (torch.optim.Optimizer): 优化器"""# 遍历每个epochfor epoch in range(num_epochs):# 训练一个epoch,得到训练损失和训练准确率train_metrics = train_epoch_ch3(net, train_iter, loss, trainer)# 评估测试集准确率test_acc = evaluate_accuracy(test_iter, net)# 打印训练进度print(f'epoch {epoch + 1}, loss {train_metrics[0]:.4f}, 'f'train acc {train_metrics[1]:.3f}, test acc {test_acc:.3f}')def predict_ch3(net, test_iter, n=6):"""可视化测试集中前n个样本的真实标签和预测标签参数:net (nn.Module): 训练好的模型test_iter (DataLoader): 测试数据迭代器n (int): 要显示的样本数量(默认6)"""# 获取测试集的一个批量(仅取第一个批量)for X, y in test_iter:break  # 只取第一个批量# 获取前n个样本的真实标签(文本形式)trues = get_fashion_mnist_lables(y[:n])# 获取前n个样本的预测标签(通过argmax得到类别索引,再转文本)preds = get_fashion_mnist_lables(net(X[:n]).argmax(axis=1))# 生成标题(真实标签 + 换行 + 预测标签)titles = [true + '\n' + pred for true, pred in zip(trues, preds)]# 显示前n个样本的图像(展平为28x28的二维数组)show_images(X[:n].reshape((n, 28, 28)), 1, n, titles=titles[:n])# ========= 公共函数 end =========# ---------------------- 主程序 ----------------------
# 超参数设置
batch_size = 256  # 每批样本数量
lr = 0.1  # 学习率
num_inputs = 784  # 输入维度(28x28=784,Fashion-MNIST图像尺寸28x28)
num_outputs = 10  # 输出维度(10个类别)# 定义模型结构并初始化参数
def init_weights(m):"""自定义权重初始化函数(用于全连接层)"""if type(m) == nn.Linear:  # 仅对全连接层(Linear)生效# 正态分布初始化权重(均值0,标准差0.01)nn.init.normal_(m.weight, std=0.01)# 偏置初始化为0(默认行为,可不写)# nn.init.zeros_(m.bias)# 构建顺序模型(Flatten + 全连接层)
net = nn.Sequential(nn.Flatten(),  # 将输入的[批量大小, 1, 28, 28]展平为[批量大小, 784]nn.Linear(num_inputs, num_outputs)  # 全连接层:784 -> 10
)# 应用权重初始化函数到模型的每一层
net.apply(init_weights)# 定义损失函数(交叉熵损失,适用于多分类问题)
loss = nn.CrossEntropyLoss()# 定义优化器(随机梯度下降,学习率lr)
trainer = torch.optim.SGD(net.parameters(), lr=lr)# 训练轮次
num_epochs = 10# 加载Fashion-MNIST数据集(训练集和测试集)
train_iter, test_iter = load_data_fashion_mnist(batch_size)# 开始训练模型
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)# 使用训练好的模型进行预测(可视化结果)
predict_ch3(net, test_iter)

http://www.dtcms.com/a/270875.html

相关文章:

  • 我是如何实现在线客服系统的极致稳定性与安全性的
  • NumPy-广播机制深入理解
  • HashMap的put、get方法详解(附源码)
  • 冷冻电镜重构的GPU加速破局:从Relion到CryoSPARC的并行重构算法
  • 【前端】异步任务风控验证与轮询机制技术方案(通用笔记版)
  • 在Centos系统上如何有效删除文件和目录的指令汇总
  • 【C++ 】第二章——类(Class)学习笔记
  • SpringGateway网关增加https证书验证
  • 基于YOLO的足球检测Web应用:从训练到部署的完整实战
  • 《心灵沟通小平台,创新发展大未来》
  • brainstorm MEG处理流程
  • 2024 睿抗编程技能赛——省赛真题解析(含C++源码)
  • 图像匹配方向最新论文--CoMatch: Covisibility-Aware Transformer for Subpixel Matching
  • 【QT】文件、多线程、网络相关内容
  • 【基础算法】贪心 (四) :区间问题
  • spring-data-jpa + Alibaba Druid多数据源案例
  • (5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
  • OpenCV图片操作100例:从入门到精通指南(4)
  • [C#/.NET] 内网开发中如何使用 System.Text.Json 实现 JSON 解析(无需 NuGet)
  • 树莓派vsftpd文件传输服务器的配置方法
  • Java 大视界 -- 基于 Java 的大数据分布式计算在生物信息学蛋白质 - 蛋白质相互作用预测中的应用(340)
  • 【算法深练】DFS题型拆解:沿着路径“深挖到底”、递归深入、回溯回探的算法解题思路
  • 【数据分析】多数据集网络分析:探索健康与退休研究中的变量关系
  • ESOP系统电子作业指导汽车零部件车间的数字化革命
  • 玛哈特网板矫平机:精密矫平金属开平板的利器
  • 钉钉企业应用开发技巧:查询表单实例数据新版SDK指南
  • 2023年华为杯研究生数学建模竞赛A题WLAN组网分析
  • 结构体指针:使用结构体指针访问和修改结构体成员。
  • 【网络】Linux 内核优化实战 - net.ipv4.tcp_ecn_fallback
  • softmax