当前位置: 首页 > news >正文

Pytorch-02数据集和数据加载器的基本原理和基本操作

1. 为什么要有数据集类和数据加载器类?

一万个人会有一万种获取并处理原始数据样本的代码,这会导致对数据的操作代码标准不一,并且很难复用。
在这里插入图片描述

为了解决这个问题,Pytorch提供了两种最基本的数据相关类:

  • torch.utils.data.Dataset: 一个数据集对象,包含每个数据样本路径以及对应标签
  • torch.utils.data.DataLoader:持有一个对Dataloader的迭代器,通过调用Dataset__getitem__函数方便地获取实际的样本-标签对

PyTorch 为不同的任务类型提供了方便的预加载数据集,例如 torchvision.datasets、torchaudio.datasets 等。这些数据集都是 torch.utils.data.Dataset 的子类,可以直接通过dataset.数据集名称的方式来方便的下载经典的数据集,在下面你会看到它的使用例。

2. Dataset类的使用方法

2.1 加载一个Fashion-MNIST数据集

Fashion-MNIST 是一个来自 Zalando 的文章图像数据集,包含 60,000 个训练样本和 10,000 个测试样本。每个样本由一张 28×28 的灰度图像和其对应的 10 个类别中的一个标签组成。

这是一个使用TorchVision预加载数据集类加载Fashion-MNIST 数据集的例子,如下是每个参数代表的意思:

  • root:是存储训练/测试数据的路径。
  • train:指定是训练数据集还是测试数据集。
  • download=True:如果数据在 root 路径下不可用,则从互联网下载。
  • transform 和 target_transform:分别指定特征和标签的转换。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data", # 指定数据集实际存放的路径(相对于本代码文件)train=True, # 指定这是训练集还是测试集download=True, # 如果在root下没有数据,从网络上自动下载transform=ToTensor() # 给每一张图片转换为Tensor的数据类型
)test_data = datasets.FashionMNIST(root="data", # 指定数据集实际存放的路径(相对于本代码文件)train=False, # 指定这是训练集还是测试集download=True, # 如果在root下没有数据,从网络上自动下载transform=ToTensor() # 给每一张图片转换为Tensor的数据类型
)

在这里插入图片描述

2.2 遍历并可视化数据集

我们可以简单的使用training_data[index]来获取Datasets类中对应index的样本。通常可以用matplotlib来可视化我们的一些训练数据集:

labels_map = { # 定义一个标签映射字典0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}figure = plt.figure(figsize=(8, 8)) # 创建一个新的画布,大小为8x8英寸
cols, rows = 3, 3 # 定义展示网格尺寸 3x3的展示网格,每个网格展示i一个图片for i in range(1, cols * rows + 1): # plt的索引从1开始,配合一下sample_idx = torch.randint(len(training_data), size=(1,)).item() # 生成一个包含1个元素的张量,item()回python数据类型之后为0到数据集大小-1的随机整数img, label = training_data[sample_idx] # 本质上是在调用__getitem__函数figure.add_subplot(rows, cols, i) # 在之前创建的图形窗口中,添加一个子图(subplot),并将当前的画笔操作对象设置为当前子图plt.title(labels_map[label]) # 子图的标题设置为对应的标签字符串plt.axis("off") # 不显示坐标轴plt.imshow(img.squeeze(), cmap="gray") # 把当前网格画好
plt.show() # 展示画布

这里我并不知道为啥要使用img.squeeze()这个方法, 直到我把img的shape的打印出来:
在这里插入图片描述
现在img是一个3维的tensor,但是plt.imshow需要输入二维的tensor,所以使用squeeze的目的是把所有的尺寸为1的维度给挤压掉,将img维度降维到2维,然后就可以用plt可视化了。

在这里插入图片描述

2.3 进阶:如何制作一个自己的数据集类

自定义的 Dataset 类必须实现三个函数:__init____len____getitem__。请看下面的实现示例:FashionMNIST 图像存储在 img_dir 目录中,而它们的标签则单独保存在 annotations_file 的 CSV 文件里。

import os
import pandas as pd
from torchvision.io import decode_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitemm__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # iloc全写为“integer location”, 表明你要通过数据的行和列的整数索引来选择数据image = decode_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

在接下来的部分将详细解释每个方法的作用。

__init__

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

这个方法会在初始化数据集的时候调用。其主要完成如下工作:

  1. 读取标签文件
  2. 指定图片文件夹路径
  3. 指定样本和标签的transform(这个下面细讲)

一个Fashion-MNIST是一个分类任务,其标签文件annotations大概长这样:

tshirt1.jpg, 0 # 样本-标签对
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

__len__

这个方法是简单返回数据集的样本数量:

def __len__(self):return len(self.img_labels)

__getitem__

这个方法是Dataset类的核心,当此方法被Dataloader调用,请求特定idx的数据时,Dataset会根据idx,读取对应的图片和标签,并对它们做出各自的transform之后,返回给Dataloader,让它把图片和标签搬运到内存.

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

3. Dataloader类的使用方法

3.1 对数据集对象配置Dataloader

Dataset类的__getitem__方法被调用的时候,他会返回一个样本-标签对。

但是在实际的模型训练中,我们还有一些别的要求,例如:

  1. 以“小批量(minibatches)”的方式传递样本。(减少单样本噪声带来的震荡,让梯度更新的方向更加稳定)
  2. 在每个周期(epoch)对数据进行重新洗牌(reshuffle),以减少模型过拟合。
  3. 使用 Python 的多进程(multiprocessing)来加快数据检索速度。

以上的要求可以通过如下的参数设定来满足:

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=5)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=5)
  • batch_size=64 设定批量大小为64
  • shuffle=True 指定一个epoch之后dataloader持有的索引要重新洗牌
  • num_workers=5 指定dataloader会同时开启5个进程去调用dataset的__getitem__方法

以上是Dataloader最基本的用法,不过,当你有GPU的时候,我推荐你也把下面两个参数打开:
pin_memory=True 开启锁页内存,减少CPU到GPU的数据传递延迟
persistent_workers=True 每个epoch结束后不销毁dataloader所开启的worker进程,而是接着用,这样剩下了worker的初始化时间

3.2 使用Dataloader遍历数据集

给Dataset配置好对应的Dataloader后,就可以开始用dataloader遍历它了。每次遍历都会返回一个batch_size的训练图片和训练标签对(这里就是64个)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader)) # 先从train_dataloader中获得一个迭代器,然后调用next获取其下一个元素
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述

由于开启了shuffle=True,所以每次遍历完整个数据集后train_dataloader持有的索引会被打乱。

http://www.dtcms.com/a/311053.html

相关文章:

  • Unity_数据持久化_XML基础
  • 【C++】第二十一节—一文详解 | 红黑树实现(规则+效率+结构+插入+查找+验证)
  • 福彩双色球第2025088期篮球号码分析
  • 电脑手机热点方式通信(上)
  • StarRocks vs ClickHouse:2025 年 OLAP 引擎终极对比指南
  • Day25-对称二叉树-
  • 仿真电路:(十七下)DC-DC升压压电路原理简单仿真
  • Clickhouse#记录隐藏字段
  • 综合:单臂路由+三层交换技术+telnet配置+DHCP
  • 【云计算】云主机的亲和性策略(四):云主机组
  • C 语言问题
  • 【机器学习】两大线性分类算法:逻辑回归与线性判别分析:找到分界线的艺术
  • 复杂路况下漏检率↓78%!陌讯动态决策模型在井盖缺失检测的实战解析
  • 系统性学习数据结构-第一讲-算法复杂度
  • Agents-SDK智能体开发[5]之集成MCP进阶
  • 机器学习 —— 决策树
  • [硬件电路-114]:模拟电路 - 信号处理电路 - 放大器的种类与比较
  • Node.js 路由与中间件
  • [硬件电路-119]:模拟电路 - 信号处理电路 - 比较器,模拟电路中的“决策者”,模拟信号到数字电平逻辑信号的转化者...
  • 音视频学习(四十六):声音的三要素
  • 小迪23-28~31-js简单回顾
  • K8S的Pod之initC容器restartPolicy新特性
  • 强光干扰下识别精度提升28%!陌讯多模态融合算法在油罐车识别的落地实践
  • ubuntu源码编译安装cmake高版本、pybind11安装、crow使用
  • 第3章栈、队列、数组和矩阵
  • 译|Netflix 技术博客:一个利用视觉-语言模型和主动学习高效构建视频分类器的框架
  • 什么叫湖仓一体
  • 一个物理引擎仿真器(mujoco这种)的计算流程
  • ubuntu 系统风扇控制软件 CoolerControl
  • 烽火HG680-KD_海思MV320处理器-安卓9-原厂系统升级包-针对解决烧录不进系统的问题