当前位置: 首页 > news >正文

PyTorch2 Python深度学习 - 卷积神经网络(CNN)介绍实例 - 使用MNIST识别手写数字示例

锋哥原创的PyTorch2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1eqxNzXEYc

课程介绍


​基于前面的机器学习Scikit-learn,深度学习Tensorflow2课程,我们继续讲解深度学习PyTorch2,所以有些机器学习,深度学习基本概念就不再重复讲解,大家务必学习好前面两个课程。本课程主要讲解基于PyTorch2的深度学习核心知识,主要讲解包括PyTorch2框架入门知识,环境搭建,张量,自动微分,数据加载与预处理,模型训练与优化,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

PyTorch2 Python深度学习 - 卷积神经网络(CNN)介绍实例 -  使用MNIST识别手写数字示例

MNIST(Modified National Institute of Standards and Technology)数据集是一个常用于机器学习和深度学习领域的经典数据集,特别是在图像识别任务中。它由美国国家标准与技术研究院(NIST)提供,广泛用于手写数字识别的研究和算法测试。

主要特点:

  1. 数据内容:

    • MNIST数据集包含了28x28像素的灰度图像,表示从0到9的手写数字。每个图像展示了一个单一的手写数字(0到9之一)。

    • 数据集分为两个部分:

      • 训练集:包含60,000个样本,用于训练模型。

      • 测试集:包含10,000个样本,用于测试和评估模型的性能。

  2. 标签信息:

    • 每个图像都有一个对应的标签,表示图像中手写数字的真实值(即0到9之间的某个数字)。

  3. 数据预处理:

    • 图像的大小是28x28像素,灰度级别为0到255,其中0表示白色,255表示黑色。图像通常在输入神经网络之前会被标准化或者归一化。

  4. 应用领域:

    • 手写数字识别:这是MINIST数据集的经典应用,用于测试各种机器学习算法的性能。

    • 分类问题:可以用于对比不同模型(如支持向量机、神经网络、决策树等)的分类准确性。

下面是具体示例:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
​
# 1,数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),  # 将图片转换为Tensortransforms.Normalize((0.5,), (0.5,))  # 数据归一化
])
​
trainset = datasets.MNIST(root='data',train=True,download=True,transform=transform
)
testset = datasets.MNIST(root='data',train=False,download=True,transform=transform
)
​
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
​
# 2,定义模型
model = nn.Sequential(# 第一层卷积层,输入1通道,输出32通道,卷积核大小3x3,填充1,nn.Conv2d(1, 32, kernel_size=3, padding=1),nn.ReLU(),  # 激活函数ReLUnn.MaxPool2d(2, 2),  # 池化层,池化核大小2x2,步长2
​# 第二层卷积层,输入32通道,输出64通道,卷积核大小3x3,填充1,nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),  # 激活函数ReLUnn.MaxPool2d(2, 2),  # 池化层,池化核大小2x2,步长2
​# 展平操作,将数据从二维转为一维nn.Flatten(),
​# 第一个全连接层,输入64*7*7,输出128nn.Linear(64 * 7 * 7, 128),nn.ReLU(),
​# 第二个全连接层,输出10个分类(数字0-9)nn.Linear(128, 10)
)
​
# 3, 定义损失函数与优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器,Adam优化算法
​
# 4,训练模型
epochs = 5
model.train()  # 训练模式
for epoch in range(epochs):for images, labels in trainloader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)  # 计算损失
​# 反向传播和优化optimizer.zero_grad()  # 清空梯度loss.backward()  # 反向传播optimizer.step()  # 更新参数
​print(f'Epoch [{epoch + 1}/{epochs}],  Loss: {loss.item():.4f}')
​
# 5, 测试模型
model.eval()  # 测试模式
correct = 0
total = 0
with torch.no_grad():  # 禁用梯度计算for images, labels in testloader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f}%')

运行输出:

http://www.dtcms.com/a/562683.html

相关文章:

  • 做一个这样的网站应该报价多少齐河县城乡建设局网站
  • phpmysql网站模板江苏中星建设集团网站
  • 网站开发配置状态报告wordpress免费版
  • SQL练习平台推荐:从入门到精通的学习路径
  • 手机网站开发 html5百度网盘可以做网站吗?
  • 手机网站模板 优帮云wordpress简易商城
  • 做封面下载网站做网站v1认证需要付费吗
  • 深圳上市公司网站建设公司佛山做网站优化公司
  • 2025年11月2日 AI大事件
  • 靖江做网站的单位购物网站开发的必要性
  • 淘宝客免费网站建设yahoo怎么提交网站
  • 学校网站建设宗旨临沂做网站公司
  • 期货数据实时展示前端实现方案K线图表展示
  • 网站项目建设的必要性郑州做网站优化的公司
  • dedecms 我的网站wordpress产品参数
  • 网站建设需求调查表做公司网站怎么推广
  • 个人网站服务器一年多少钱站长工具seo综合查询怎么去掉
  • 用模板做网站会被盗吗南通建设信息网站
  • 怎么开个人网站赚钱怎么在导航网站上做推广
  • 建设部网站官网证书编号吴江和城乡建设局网站
  • 网站建设需要的费用重庆建设工程信息网30系统
  • 域名注册最后是网站求职网站怎么做
  • 17.如何利用ArcGIS进行空间统计分析
  • 建设门户网站的请示小红书网络营销方式
  • 外贸网站营销推广鑫诺科技网站建设
  • 郑州网站建设模板换网站公司
  • 设计素材网站会员怎么买划算泉州公司做网站
  • 零基础学JAVA--Day21(房屋出租系统+韩顺平Utility类原码)
  • 广东手机网站制作电话平面设计师的出路
  • 京东网站优化广州注册公司有什么优惠政策