当前位置: 首页 > news >正文

【Pytorch】数据集的加载和处理(一)

 Pytorch torchvision 包提供了很多常用数据集

数据按照用途一般分为三组:训练(train)、验证(validation)和测试(test)。使用训练数据集来训练模型,使用验证数据集跟踪模型在训练期间的性能,使用测试数据集对模型进行最终评估。

目录

导入MNIST训练数据集

提取训练数据和标签

同理操作验证数据集

给张量添加维度

打印示例图像


导入MNIST训练数据集

从 torchvision导入MNIST训练数据集

import torch
import torchvision
from torchvision import datasets
train_data=datasets.MNIST("./data",train=True,download=True)

datasets.MNIST是Pytorch的内置函数

train=True指导入的数据作为训练数据集

download=True若根目录下没有数据集时自动下载

 导入完成后可以看到MINST文件内的数据集

提取训练数据和标签

x_train, y_train=train_data.data,train_data.targets
print(x_train.shape)
print(y_train.shape)

x_train存储60000张28*28的图片,y_train存储60000张图片对应的数字(label)

同理操作验证数据集

从 torchvision导入MNIST验证数据集并提取数据和标签

val_data=datasets.MNIST("./data", train=False, download=True)
x_val,y_val=val_data.data, val_data.targets
print(x_val.shape)
print(y_val.shape)

 

给张量添加维度

Pytorch中张量可以是一维、二维、三维或者更高维度的数据结构。一维张量类似于向量,二维张量类似于矩阵,三维张量类似一系列矩阵的堆叠。添加新的维度可以更好地对数据进行表示和处理。

if len(x_train.shape)==3:x_train=x_train.unsqueeze(1)
print(x_train.shape)if len(x_val.shape)==3:x_val=x_val.unsqueeze(1)
print(x_val.shape)

 .unsqueeze(0)指添加在第一个维度

也可以通过x_train.view(60000,1,28,28)添加维度

可以看到张量由三维变为了四维 

打印示例图像

引入所需的包,定义一个辅助函数,将张量显示为图像

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
def show(img):npimg = img.numpy()npimg_tr=np.transpose(npimg, (1,2,0))plt.imshow(npimg_tr,interpolation='nearest')

创建一个10*10的网格,每行10张图片,pedding=3指间隔为3

x_grid=utils.make_grid(x_train[:100], nrow=10, padding=3)
print(x_grid.shape)
show(x_grid)

utils.make_grid实际上是将多张图片拼接起来,参照官方介绍:

http://www.dtcms.com/a/293381.html

相关文章:

  • 使用ubuntu:20.04和ubuntu:jammy构建secretflow环境
  • ndarray的创建(小白五分钟从入门到精通)
  • 嵌入式开发学习(第三阶段 Linux系统开发)
  • 数据资产——解读数据资产全过程管理手册2025【附全文阅读】
  • [c++11]constexpr
  • 考研数据结构Part1——单链表知识点总结
  • 陷波滤波器设计全解析:原理、传递函数与MATLAB实现
  • Netty中AbstractReferenceCountedByteBuf对AtomicIntegerFieldUpdater的使用
  • 威胁情报:Solana 开源机器人盗币分析
  • Automotive SPICE
  • git的版本冲突
  • 大模型——Data Agent:超越 BI 与 AI 的边界
  • 用ESP32打造全3D打印四驱遥控车:无需APP的Wi-Fi控制方案
  • 从0开始的中后台管理系统-2
  • 课题学习笔记2——中华心法问答系统
  • 汽车行业数字化——解读52页汽车设计制造一体化整车产品生命周期PLM解决方案【附全文阅读】
  • 记录更新时间用java的new date还是数据库的now
  • 深入理解 C 语言数据类型:从内存到应用的全面解析
  • CAN基础知识 - 进阶版
  • 消息推送功能设计指南:精准触达与用户体验的平衡之道
  • Spring Boot 中集成ShardingSphere-JDBC的基本使用
  • Kibana报错[security_exception] current license is non-compliant for [security]
  • HCIA/IP(一二章)笔记
  • TTL+日志的MDC实现简易链路追踪
  • 强化学习理论
  • 计算机是怎么样工作的
  • 在 Ubuntu 22.04 上安装并优化 Nginx nginx入门操作 稍难,需要有一定理论 多理解 多实践
  • Class13预测房价代码
  • Google Gemini 体验
  • 从零开始学CTF(第二十五期)