当前位置: 首页 > news >正文

Pytorch下载Mnist手写数据识别训练数据集的代码详解

datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

1. datasets.MNIST

这是torchvision.datasets模块中的一个类,专门用于加载MNIST数据集。MNIST是一个著名的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本,每个样本是28x28的灰度图像。

2. 参数解释

root='./data'
  • 作用:指定数据集下载和存储的根目录。

  • 解释:这里设置为当前目录下的data文件夹。如果该文件夹不存在,PyTorch会自动创建它。

  • 默认值:通常没有默认值,必须指定。

train=False
  • 作用:指定加载的是训练集还是测试集。

  • 解释

    • train=True:加载训练集(60,000个样本)。

    • train=False:加载测试集(10,000个样本)。

  • 默认值:通常是True(加载训练集),但明确指定是好的实践。

download=True
  • 作用:控制是否下载数据集。

  • 解释

    • download=True:如果数据集在root目录下不存在,则自动下载。

    • download=False:不下载,仅尝试从root目录加载。

  • 默认值:通常是False,但这里显式设置为True以确保下载。

transform=transforms.ToTensor()
  • 作用:指定对加载的数据进行何种预处理或转换。

  • 解释

    • transforms.ToTensor()是PyTorch的一个转换函数,它将PIL图像或NumPy数组转换为PyTorch张量(torch.Tensor),并自动进行以下操作:

      1. 将图像的像素值从[0, 255]缩放到[0.0, 1.0](除以255)。

      2. 将图像的形状从(H, W, C)(高度、宽度、通道)转换为(C, H, W)(通道、高度、宽度)。对于MNIST,因为是灰度图像,所以通道数为1,形状从(28, 28)变为(1, 28, 28)

    • 如果不指定transform,返回的是PIL图像格式。

  • 默认值:如果不指定,返回原始数据(通常是PIL图像)。

3. 返回值

这行代码的返回值是一个torchvision.datasets.MNIST对象,可以像数据集一样使用:

  • 可以通过索引(如dataset[0])访问单个样本。

  • 可以通过len(dataset)获取数据集大小。

  • 通常用于DataLoader中批量加载数据。

4. 完整示例

python

from torchvision import datasets, transforms# 下载并加载MNIST测试集
test_dataset = datasets.MNIST(root='./data',      # 数据存储目录train=False,        # 加载测试集download=True,      # 如果数据不存在则下载transform=transforms.ToTensor()  # 转换为张量并归一化到[0,1]
)# 打印数据集大小
print(len(test_dataset))  # 输出: 10000# 获取第一个样本
image, label = test_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 28, 28])
print(label)        # 输出: 7(标签)

5. 其他常见参数(虽然不是这里使用的)

  • target_transform:对标签(target)进行转换的函数(类似transform对图像的作用)。

  • 某些数据集可能有额外参数(如MNIST通常没有,但其他数据集可能有split等)。

总结:这行代码的作用是从PyTorch自动下载MNIST测试集到./data目录,并将图像转换为PyTorch张量格式,方便后续用于深度学习模型的测试。

http://www.dtcms.com/a/285591.html

相关文章:

  • PyTorch新手实操 安装
  • 填坑 | React Context原理
  • SpringMVC + Tomcat10
  • 小结:Spring MVC 的 XML 的经典配置方式
  • 计算机视觉与机器视觉
  • Tensorflow小白安装教程(包含GPU版本和CPU版本)
  • C++并发编程-13. 无锁并发队列
  • div和span区别
  • 【Python】python 爬取某站视频批量下载
  • 前端实现 web获取麦克风权限 录制音频 (需求:ai对话问答)
  • 20250718【顺着234回文链表做两题反转】Leetcodehot100之20692【直接过12明天吧】今天计划
  • AugmentCode还没对个人开放?
  • STL—— list迭代器封装的底层讲解
  • 71 模块编程之新增一个字符设备
  • Proto文件从入门到精通——现代分布式系统通信的基石(含实战案例)
  • 标题 “Python 网络爬虫 —— selenium库驱动浏览器
  • 光伏电站工业通信网络解决方案:高可靠工业通信架构与设备选型
  • 开源短链接工具 Sink 无需服务器 轻松部署到 Workers / Pages
  • 西门子工业软件全球高级副总裁兼大中华区董事总经理梁乃明先生一行到访庭田科技
  • ArcGIS Pro+PS 实现地形渲染效果图
  • WinDbg命令
  • FastAdmin框架超级管理员密码重置与常规admin安全机制解析-卓伊凡|大东家
  • 本地部署DeepSeek-R1并打通知识库
  • 数字地与模拟地隔离
  • 【C语言】深入理解柔性数组:特点、使用与优势分析
  • Cursor替代,公测期间免费使用Claude4
  • 首个直播流扩散(LSD)AI模型:MirageLSD,它可以实时把任意视频流转换成你的自定义服装风格——虚拟换装新体验
  • mpiigaze的安装过程一
  • 【后端】.NET Core API框架搭建(10) --配置163邮件发送服务
  • 【锂电池剩余寿命预测】TCN时间卷积神经网络锂电池剩余寿命预测(Pytorch完整源码和数据)