当前位置: 首页 > news >正文

pytorch-利用letnet5框架深度学习手写数字识别

LetNet-5

利用letnet5框架深度学习手写数字识别
LeNet-5 项目说明
项目简介

本项目实现了经典的 LeNet-5 卷积神经网络模型,主要用于手写数字识别任务。模型结构包括两个卷积层、两个池化层和三个全连接层,适用于 MNIST 数据集。

项目结构
.
├── model.py # LeNet-5 模型定义
├── plot.py # 数据加载与可视化
├── train.py # 模型训练脚本
├── test.py # 模型测试与可视化
├── best_model.pth # 训练后的最佳模型权重
├── README.md # 项目说明文档

安装依赖

pip install torch torchvision matplotlib

数据加载与预处理

在 plot.py 中,定义了 test_Loader,用于加载 MNIST 测试数据集。数据预处理包括:

将图像转换为 Tensor

标准化图像数据

加载器使用 DataLoader 进行批处理

模型定义

在 model.py 中,定义了 LeNet-5 模型结构。模型包括以下层:

输入层:32x32 灰度图像

C1:卷积层,6 个 5x5 卷积核,输出 28x28 特征图

S2:池化层,2x2 平均池化,输出 14x14 特征图

C3:卷积层,16 个 5x5 卷积核,输出 10x10 特征图

S4:池化层,2x2 平均池化,输出 5x5 特征图

C5:卷积层,120 个 1x1 卷积核,输出 1x1 特征图

F6:全连接层,84 个神经元

输出层:10 个神经元,对应 10 个数字类别

模型训练

在 train.py 中,定义了模型训练过程,包括:

加载训练数据

定义损失函数和优化器

训练模型并保存最佳权重至 best_model.pth

模型测试与可视化

在 test.py 中,定义了模型测试过程:

加载测试数据

加载训练好的模型权重

计算测试准确率

可视化预测结果:

import torch
import matplotlib.pyplot as plt
import modeldef test_model_process(model, test_data, max_visualize=10):test_acc = 0.0test_num = 0visualize_count = 0  # 可视化计数model.eval()with torch.no_grad():for test_x, test_y in test_data:output = model(test_x)pre_label = torch.argmax(output, dim=1)test_acc += torch.sum(pre_label == test_y)test_num += test_x.size(0)# 遍历 batchfor i in range(test_x.size(0)):if visualize_count >= max_visualize:breaklabel = test_y[i].item()result = pre_label[i].item()# 可视化img = test_x[i].squeeze().cpu()  # 去掉 channelplt.imshow(img, cmap='gray')title_color = 'green' if label == result else 'red'plt.title(f"预测值:{result} 真实值:{label}", color=title_color)plt.axis('off')plt.show()# 控制台输出if label == result:print("预测值:", result, "-------", "真实值", label)else:print("预测值:", result, "-----------------------", "真实值", label)visualize_count += 1test_avg_acc = test_acc.item() / test_numprint("测试准确率:", test_avg_acc)

使用方法

训练模型:

python train.py

测试模型并可视化:

python test.py

资源连接链接🔗

http://www.dtcms.com/a/351592.html

相关文章:

  • 漫谈《数字图像处理》之霍夫变换发展历程与演进脉络
  • 类似ant design和element ui的八大Vue的UI框架详解优雅草卓伊凡
  • (vue)el-progress左侧添加标签/名称
  • C++学习(4)模板与STL
  • 虚幻5引擎:我们是在创造世界,还是重新发现世界?
  • 8.26 review
  • 【大前端】React统计所有网络请求的成功率、失败率以及统一入口处理失败页面
  • Ubuntu22.04安装OBS
  • 嵌入式系统学习Day23(进程)
  • 2025.8.26总结
  • 【系统架构设计(二)】系统工程与信息系统基础中:信息系统基础
  • 数据结构青铜到王者第四话---LinkedList与链表(1)
  • 【SystemUI】新增实体键盘快捷键说明
  • 【SystemUI】锁屏点击通知显示的解锁界面和通知重叠
  • [Sync_ai_vid] 唇形同步推理流程 | Whisper架构
  • 技术分享︱国产化突破:开源MDO工具链在新一代神威超算上的安装与调试
  • DevExpress WinForms中文教程:Data Grid - Excel样式的自定义过滤器对话框
  • 在Excel和WPS表格中输入分数的两种方法
  • 自然处理语言NLP: 基于双分支 LSTM 的酒店评论情感分析模型构建与实现
  • PostgreSQL快速入门
  • 会议室预约小程序主要功能及预约审批流程
  • Java大厂面试全解析:从Spring Boot到微服务架构实战
  • Hadoop MapReduce 任务/输入数据 分片 InputSplit 解析
  • ProfiNet转CAN/CANopen网关技术详解-三格电子
  • uniapp uview吸顶u-sticky 无效怎么办?
  • 利用Certbot生成ssl证书配置到nginx
  • Android之穿山甲广告接入
  • Flutter 项目命名规范 提升开发效率
  • 深度学习(三):PyTorch 损失函数:按任务分类的实用指南
  • Swift 解法详解 LeetCode 363:矩形区域不超过 K 的最大数值和