day38 python Dataset和Dataloader
目录
一、背景知识
二、数据预处理与数据集加载
三、Dataset类:定义“数据是什么”和“如何获取单个样本”
1. __getitem__方法详解
2. __len__方法详解
3. 自定义MNIST数据集类
4. 可视化原始图像
四、DataLoader类:定义“如何批量加载数据”和“加载策略”
五、总结
1. Dataset类的核心要点
2. DataLoader类的核心要点
3. 两者的协同工作
一、背景知识
MNIST数据集是一个非常经典的数据集,包含60000张训练图片和10000张测试图片,每张图片大小为28×28像素,共包含10个类别(0到9的数字)。由于每个数据的维度比较小,既可以视为结构化数据,用机器学习、MLP(多层感知机)训练,也可以视为图像数据,用卷积神经网络训练。
在处理大规模数据集时,显存常常无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch的DataLoader
类可以自动将数据集切分为多个批次(batch),并支持多线程加载数据,从而提高数据加载效率。而Dataset
类则用于定义数据集的读取方式和预处理方式。
二、数据预处理与数据集加载
在开始之前,我们需要对数据进行预处理。PyTorch的transforms
模块提供了一系列常用的图像预处理操作。以下是我们的预处理流程:
transform = transforms.Compose([transforms.ToTensor(), # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差
])
接下来,我们加载MNIST数据集。如果没有下载过,datasets.MNIST
会自动下载:
train_dataset = datasets.MNIST(root='./data', # 数据存储路径train=True, # 加载训练集download=True, # 如果没有数据则自动下载transform=transform # 应用预处理
)test_dataset = datasets.MNIST(root='./data', # 数据存储路径train=False, # 加载测试集transform=transform # 应用预处理
)
这里需要注意的是,PyTorch的思路是在数据加载阶段就完成数据的预处理,这与我们通常的“先有数据集,后续再处理”的思路有所不同。
三、Dataset类:定义“数据是什么”和“如何获取单个样本”
torch.utils.data.Dataset
是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:__len__
和__getitem__
。
-
__len__
方法:返回数据集的样本总数。 -
__getitem__
方法:根据索引idx
返回对应样本的数据和标签。
这两个方法是PyTorch对数据集的基本要求,只有实现了它们,数据集才能被DataLoader
等工具兼容。这类似于一种接口约定,就像函数参数的规范一样。
在Python中,__getitem__
和__len__
是类的特殊方法(也叫魔术方法),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,从而赋予类特定的行为。
1. __getitem__
方法详解
__getitem__
方法用于让对象支持索引操作。当使用[]
语法访问对象元素时,Python会自动调用该方法。例如:
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]my_list_obj = MyList()
print(my_list_obj[2]) # 输出:30
通过定义__getitem__
方法,MyList
类的实例能够像Python内置的列表一样使用索引获取元素。
2. __len__
方法详解
__len__
方法用于返回对象中元素的数量。当使用内置函数len()
作用于对象时,Python会自动调用该方法。例如:
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)my_list_obj = MyList()
print(len(my_list_obj)) # 输出:5
这里定义的__len__
方法,使得MyList
类的实例可以像普通列表一样被len()
函数调用获取长度。
3. 自定义MNIST数据集类
为了更好地理解Dataset
类的使用,我们来实现一个简化版本的MNIST数据集类:
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加载图片路径和标签self.data, self.targets = fetch_mnist_data(root, train) # 假设 fetch_mnist_data 是一个函数self.transform = transform # 预处理操作def __len__(self):return len(self.data) # 返回样本总数def __getitem__(self, idx):# 获取指定索引的图像和标签img, target = self.data[idx], self.targets[idx]# 应用图像预处理if self.transform is not None:img = self.transform(img)return img, target # 返回处理后的图像和标签
在这个类中,__getitem__
方法负责根据索引获取单个样本,并应用预处理操作(如ToTensor
、Normalize
)。这就好比厨师在准备单个菜品时,会进行切菜、调味等预处理操作。
4. 可视化原始图像
为了查看数据集中的图像,我们可以定义一个可视化函数imshow
,并随机选择一张图片进行展示:
def imshow(img):img = img * 0.3081 + 0.1307 # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
image, label = train_dataset[sample_idx] # 获取图片和标签
print(f"Label: {label}")
imshow(image)
四、DataLoader类:定义“如何批量加载数据”和“加载策略”
DataLoader
类的职责是将Dataset
中的数据批量加载出来,并支持多线程加载,从而提高数据加载效率。它的使用非常简单:
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片
)
DataLoader
类的主要参数包括:
-
dataset
:要加载的数据集。 -
batch_size
:每个批次的样本数量。 -
shuffle
:是否随机打乱数据。 -
num_workers
:加载数据时使用的子进程数量,默认为0(不使用多进程)。
DataLoader
类可以看作是“服务员”,它将Dataset
类准备好的“菜品”(单个样本)按照订单(批量大小、是否打乱等策略)组合并上桌(批量加载)。
五、总结
通过以上内容的学习,我们可以对Dataset
类和DataLoader
类进行如下的总结:
维度 | Dataset | DataLoader |
---|---|---|
核心职责 | 定义“数据是什么”和“如何获取单个样本” | 定义“如何批量加载数据”和“加载策略” |
核心方法 | __getitem__ (获取单个样本)、__len__ (样本总数) | 无自定义方法,通过参数控制加载逻辑 |
预处理位置 | 在__getitem__ 中通过transform 执行预处理 | 无预处理逻辑,依赖Dataset 返回的预处理后数据 |
并行处理 | 无(仅单样本处理) | 支持多进程加载(num_workers>0 ) |
典型参数 | root (数据路径)、transform (预处理) | batch_size 、shuffle 、num_workers |
1. Dataset类的核心要点
-
定义数据的内容和格式:包括数据存储路径/来源、原始数据的读取方式、样本的预处理逻辑以及返回值格式。
-
实现两个核心方法:
__len__
和__getitem__
,这是PyTorch对数据集的基本要求,也是与DataLoader
兼容的关键。
2. DataLoader类的核心要点
-
定义数据的加载方式和批量处理逻辑:通过
batch_size
控制每个批次的样本数量,通过shuffle
决定是否随机打乱数据,通过num_workers
设置多进程加载的子进程数量。 -
依赖Dataset返回的预处理后数据:
DataLoader
本身不负责预处理,而是直接使用Dataset
返回的已经预处理好的数据。
3. 两者的协同工作
-
Dataset
类是“厨师”,负责准备单个样本,包括数据的读取和预处理。 -
DataLoader
类是“服务员”,负责将“厨师”准备好的单个样本按照订单(批量大小、是否打乱等策略)组合并上桌(批量加载)。
通过Dataset
类和DataLoader
类的协同工作,我们可以高效地处理和加载大规模数据集,为深度学习模型的训练提供有力支持。
@浙大疏锦行