当前位置: 首页 > news >正文

使用PyTorch实现LeNet-5并在Fashion-MNIST数据集上训练

本文将展示如何使用PyTorch实现经典的LeNet-5卷积神经网络,并在Fashion-MNIST数据集上进行训练和评估。代码包含完整的网络定义、数据加载、训练流程及结果可视化。

1. 导入依赖库

import torch
from torch import nn
from d2l import torch as d2l

2. 定义LeNet-5网络结构

通过PyTorch的nn.Sequential构建网络,包含卷积层、池化层和全连接层:

class Reshape(nn.Module):
    def forward(self, x):
        return x.view(-1, 1, 28, 28)  # 将输入重塑为1x28x28

net = nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16*5*5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)

3. 验证各层输出形状

输入随机数据检查网络各层的输出形状:

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(f"{layer.__class__.__name__}输出形状:\t{X.shape}")

输出结果:

Reshape output shape:     torch.Size([1, 1, 28, 28])
Conv2d output shape:      torch.Size([1, 6, 28, 28])
...
Linear output shape:      torch.Size([1, 10])

4. 加载Fashion-MNIST数据集

使用d2l库快速加载数据,设置批量大小为256:

batch_size = 256
train_data, test_data = d2l.load_data_fashion_mnist(batch_size)

5. 定义评估函数

修改后的准确率评估函数支持GPU计算:

def evaluate_accuracy(net, data, device=None):
    if isinstance(net, nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device
    metric = d2l.Accumulator(2)
    for X, y in data:
        if isinstance(X, list):
            X = [x.to(device) for x in X]
        else:
            X = X.to(device)
        y = y.to(device)
        metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

6. 训练与评估模型

调用d2l.train_ch6进行训练,设置10个周期和学习率0.9:

lr, num_epochs = 0.9, 10
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, d2l.try_gpu())

输出结果:

loss 0.470, train acc 0.822, test acc 0.805
80458.2 examples/sec on cuda:0

7. 训练结果可视化

训练过程中会自动生成损失和准确率曲线:

http://www.dtcms.com/a/106644.html

相关文章:

  • 【Linux】内核驱动学习笔记(二)
  • 基于Spring AI与Ollama构建本地DeepSeek对话机器人
  • 数据库分库分表中间件及对比
  • ensp 网络模拟器 思科华为基于VLANIF的公司网络搭建
  • 2025.4.2总结
  • Go语言GC:三色标记法工程启示|Go语言进阶(3)
  • K-means算法
  • 从零搭建微服务项目Pro(第7-1章——分布式雪花算法)
  • cmake(11):list 选项 排序 SORT,定义宏 add_definitions,cmake 里预定义的 8 个宏
  • Git 命令大全:通俗易懂的指南
  • 基于大模型预测风湿性心脏病二尖瓣病变的多维度诊疗研究报告
  • 内网隔离环境下Java实现图片预览的三大解决方案
  • 【Django开发】前后端分离django美多商城项目第15篇:商品搜索,1. Haystack介绍和安装配置【附代码文档】
  • 从 ZStack 获取物理机与云主机信息并导出 Excel 文件
  • visual studio 2022的windows驱动开发
  • C# System.Text.Json 中 JsonIgnoreCondition 使用详解
  • Linux2 CD LL hostnamectl type mkdir dudo
  • 跨系统平台实践:在内网自建kylin服务版系统yum源
  • 面基JavaEE银行金融业务逻辑层处理金融数据类型BigDecimal
  • AI提示词:好评生成器
  • 鸿蒙NEXT小游戏开发:数字华容道
  • 详解相机的内参和外参,以及内外参的标定方法
  • 背包DP总结
  • GO语言 使用protobuf
  • 【第十三届“泰迪杯”数据挖掘挑战赛】【2025泰迪杯】【代码篇】A题解题全流程(持续更新)
  • 全国产ADC 16bit 2通道1G采样 双FMC子板
  • C++多继承
  • 【抓包工具】win 10 / win 11:Charles 下载、安装、配置(快捷方式、默认端口、登录、https 证书)
  • 【git】VScode修改撤回文件总是出现.lh文件,在 ​所有 Git 项目 中全局忽略特定文件
  • MacOS 的 AI Agent 新星,本地沙盒驱动,解锁 macOS 操作新体验!