当前位置: 首页 > news >正文

day38 python Dataset和Dataloader

目录

一、背景知识

二、数据预处理与数据集加载

三、Dataset类:定义“数据是什么”和“如何获取单个样本”

1. __getitem__方法详解

2. __len__方法详解

3. 自定义MNIST数据集类

4. 可视化原始图像

四、DataLoader类:定义“如何批量加载数据”和“加载策略”

五、总结

1. Dataset类的核心要点

2. DataLoader类的核心要点

3. 两者的协同工作


一、背景知识

MNIST数据集是一个非常经典的数据集,包含60000张训练图片和10000张测试图片,每张图片大小为28×28像素,共包含10个类别(0到9的数字)。由于每个数据的维度比较小,既可以视为结构化数据,用机器学习、MLP(多层感知机)训练,也可以视为图像数据,用卷积神经网络训练。

在处理大规模数据集时,显存常常无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch的DataLoader类可以自动将数据集切分为多个批次(batch),并支持多线程加载数据,从而提高数据加载效率。而Dataset类则用于定义数据集的读取方式和预处理方式。

二、数据预处理与数据集加载

在开始之前,我们需要对数据进行预处理。PyTorch的transforms模块提供了一系列常用的图像预处理操作。以下是我们的预处理流程:

transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])

接下来,我们加载MNIST数据集。如果没有下载过,datasets.MNIST会自动下载:

train_dataset = datasets.MNIST(root='./data',  # 数据存储路径train=True,  # 加载训练集download=True,  # 如果没有数据则自动下载transform=transform  # 应用预处理
)test_dataset = datasets.MNIST(root='./data',  # 数据存储路径train=False,  # 加载测试集transform=transform  # 应用预处理
)

这里需要注意的是,PyTorch的思路是在数据加载阶段就完成数据的预处理,这与我们通常的“先有数据集,后续再处理”的思路有所不同。

三、Dataset类:定义“数据是什么”和“如何获取单个样本”

torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:__len____getitem__

  • __len__方法:返回数据集的样本总数。

  • __getitem__方法:根据索引idx返回对应样本的数据和标签。

这两个方法是PyTorch对数据集的基本要求,只有实现了它们,数据集才能被DataLoader等工具兼容。这类似于一种接口约定,就像函数参数的规范一样。

在Python中,__getitem____len__是类的特殊方法(也叫魔术方法),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,从而赋予类特定的行为。

1. __getitem__方法详解

__getitem__方法用于让对象支持索引操作。当使用[]语法访问对象元素时,Python会自动调用该方法。例如:

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]my_list_obj = MyList()
print(my_list_obj[2])  # 输出:30

通过定义__getitem__方法,MyList类的实例能够像Python内置的列表一样使用索引获取元素。

2. __len__方法详解

__len__方法用于返回对象中元素的数量。当使用内置函数len()作用于对象时,Python会自动调用该方法。例如:

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)my_list_obj = MyList()
print(len(my_list_obj))  # 输出:5

这里定义的__len__方法,使得MyList类的实例可以像普通列表一样被len()函数调用获取长度。

3. 自定义MNIST数据集类

为了更好地理解Dataset类的使用,我们来实现一个简化版本的MNIST数据集类:

class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加载图片路径和标签self.data, self.targets = fetch_mnist_data(root, train)  # 假设 fetch_mnist_data 是一个函数self.transform = transform  # 预处理操作def __len__(self):return len(self.data)  # 返回样本总数def __getitem__(self, idx):# 获取指定索引的图像和标签img, target = self.data[idx], self.targets[idx]# 应用图像预处理if self.transform is not None:img = self.transform(img)return img, target  # 返回处理后的图像和标签

在这个类中,__getitem__方法负责根据索引获取单个样本,并应用预处理操作(如ToTensorNormalize)。这就好比厨师在准备单个菜品时,会进行切菜、调味等预处理操作。

4. 可视化原始图像

为了查看数据集中的图像,我们可以定义一个可视化函数imshow,并随机选择一张图片进行展示:

def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')  # 显示灰度图像plt.show()sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()  # 随机选择一张图片的索引
image, label = train_dataset[sample_idx]  # 获取图片和标签
print(f"Label: {label}")
imshow(image)

四、DataLoader类:定义“如何批量加载数据”和“加载策略”

DataLoader类的职责是将Dataset中的数据批量加载出来,并支持多线程加载,从而提高数据加载效率。它的使用非常简单:

train_loader = DataLoader(train_dataset,batch_size=64,  # 每个批次64张图片shuffle=True  # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000  # 每个批次1000张图片
)

DataLoader类的主要参数包括:

  • dataset:要加载的数据集。

  • batch_size:每个批次的样本数量。

  • shuffle:是否随机打乱数据。

  • num_workers:加载数据时使用的子进程数量,默认为0(不使用多进程)。

DataLoader类可以看作是“服务员”,它将Dataset类准备好的“菜品”(单个样本)按照订单(批量大小、是否打乱等策略)组合并上桌(批量加载)。

五、总结

通过以上内容的学习,我们可以对Dataset类和DataLoader类进行如下的总结:

维度DatasetDataLoader
核心职责定义“数据是什么”和“如何获取单个样本”定义“如何批量加载数据”和“加载策略”
核心方法__getitem__(获取单个样本)、__len__(样本总数)无自定义方法,通过参数控制加载逻辑
预处理位置__getitem__中通过transform执行预处理无预处理逻辑,依赖Dataset返回的预处理后数据
并行处理无(仅单样本处理)支持多进程加载(num_workers>0
典型参数root(数据路径)、transform(预处理)batch_sizeshufflenum_workers

1. Dataset类的核心要点

  • 定义数据的内容和格式:包括数据存储路径/来源、原始数据的读取方式、样本的预处理逻辑以及返回值格式。

  • 实现两个核心方法__len____getitem__,这是PyTorch对数据集的基本要求,也是与DataLoader兼容的关键。

2. DataLoader类的核心要点

  • 定义数据的加载方式和批量处理逻辑:通过batch_size控制每个批次的样本数量,通过shuffle决定是否随机打乱数据,通过num_workers设置多进程加载的子进程数量。

  • 依赖Dataset返回的预处理后数据DataLoader本身不负责预处理,而是直接使用Dataset返回的已经预处理好的数据。

3. 两者的协同工作

  • Dataset类是“厨师”,负责准备单个样本,包括数据的读取和预处理。

  • DataLoader类是“服务员”,负责将“厨师”准备好的单个样本按照订单(批量大小、是否打乱等策略)组合并上桌(批量加载)。

通过Dataset类和DataLoader类的协同工作,我们可以高效地处理和加载大规模数据集,为深度学习模型的训练提供有力支持。

@浙大疏锦行

相关文章:

  • SSM整合:Spring+SpringMVC+MyBatis完美融合实战指南
  • 基于大模型的慢性胃炎全周期预测与诊疗方案研究报告
  • 【Quest开发】空间音频的使用
  • 异常:UnsupportedOperationException: null
  • 【运维_日常报错解决方案_docker系列】一、docker系统不起来
  • OpenCV CUDA模块图像处理------颜色空间处理之用于执行伽马校正(Gamma Correction)函数gammaCorrection()
  • OpenCV CUDA模块图像处理------颜色空间处理之GPU 上对两张带有 Alpha 通道的图像进行合成操作函数alphaComp()
  • 传统数据表设计与Prompt驱动设计的范式对比:以NBA投篮数据表为例
  • 【请关注】VC++ MFC常见异常问题及处理方法
  • 【LeetCode 热题 100】打家劫舍 / 零钱兑换 / 单词拆分 / 乘积最大子数组 / 最长有效括号
  • react基础技术栈
  • [React]实现一个类zustand公共状态库
  • Nginx 性能优化全解析:从进程到安全的深度实践
  • HJ25 数据分类处理【牛客网】
  • 【前端】【React】React性能优化系统总结
  • 嵌入式学习--江协stm32day1
  • 电芯单节精密焊接机:以先进功能与特点赋能电池制造科技升级
  • java-jdk8新特性Stream流
  • 无人机多人协同控制技术解析
  • 武汉火影数字VR大空间制作
  • 北京中国建设工程造价管理协会网站/seo是什么意思怎么解决
  • 中国最好的网站建设公司/南京seo网站管理
  • 怎么利用网站做淘宝客/aso优化服务平台
  • 用v9做的网站上传服务器/四川网络推广seo
  • 网站建设通讯设备中企动力/临沂seo推广外包
  • 在线做图表网站/重庆森林电影高清在线观看