【pytorch学习打卡挑战】day3 Pytorch的Dataset与DataLoader及DataLoader源码剖析
前言
本专题致力于学习Pytorch及其相关项目。
参照B站教程
【4、PyTorch的Dataset与DataLoader详细使用教程】
【5、深入剖析PyTorch DataLoader源码】
今日任务
第4、5个视频,主要围绕官方文档来介绍Dataset与DataLoader。
回顾时,建议直接看官方文档:
pytorch官方文档
内容总结
基础知识
PyTorch中的Dataset与DataLoader
PyTorch提供了Dataset
和DataLoader
两个核心类来处理数据加载与批处理。Dataset
负责存储样本及其对应的标签(处理单个样本),DataLoader
则围绕Dataset
进行迭代,提供批量加载、打乱数据和多进程加载等功能。
Dataset
是一个抽象类,用户自定义数据集需继承此类并实现__len__
和__getitem__
方法。__len__
返回数据集大小,__getitem__
根据索引返回单个样本。
DataLoader
接收一个Dataset
对象,并封装了批量加载、多线程等复杂逻辑。常用参数包括batch_size
(批大小)、shuffle
(是否打乱数据)和num_workers
(加载数据的进程数)。
自定义数据读取函数的方法
若要实现自定义数据读取逻辑,需继承Dataset
类并重写关键方法。以下是实现步骤:
继承torch.utils.data.Dataset
创建一个新类并继承Dataset
,初始化时加载数据路径或其他元信息。例如处理图像分类任务时,可保存图像路径和标签列表。
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data_path, transform=None):self.data = [...] # 加载数据路径或元信息self.labels = [...] # 加载对应标签self.transform = transform # 数据预处理(如标准化、增强)
实现__len__
方法
返回数据集的总样本数,通常为数据列表的长度。
def __len__(self):return len(self.data)
实现__getitem__
方法
根据索引返回单个样本和标签。在此方法中完成数据读取(如打开图像文件)和预处理。若提供transform
,需在此处调用。
def __getitem__(self, idx):sample = self.data[idx] # 读取数据(如用PIL打开图像)label = self.labels[idx]if self.transform:sample = self.transform(sample) # 应用预处理return sample, label
使用DataLoader
封装
实例化自定义Dataset
后,通过DataLoader
实现批量加载和多进程加速。
from torch.utils.data import DataLoaderdataset = CustomDataset(data_path="path/to/data", transform=my_transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
关键注意事项
- 数据预处理:在
__getitem__
中完成数据读取和转换,避免在__init__
中加载全部数据以减少内存占用。 - 线程安全:若使用多进程(
num_workers>0
),确保数据读取逻辑是线程安全的。 - 效率优化:对于大规模数据,建议使用延迟加载(lazy loading)或内存映射技术。
通过上述方法,用户可灵活适配各种数据格式(如文本、音频、视频)和存储结构(如文件夹、数据库)。
DataLoader 的基本结构
PyTorch 的 DataLoader
是用于批量加载数据的工具,支持多进程加载和自定义数据采样。其核心功能包括:
- 数据集的迭代
- 批处理(batching)
- 数据打乱(shuffling)
- 多进程加载(multiprocessing)
DataLoader
的源码位于 torch/utils/data/dataloader.py
,主要由以下几个关键组件构成:
DataLoader
类:用户接口,封装了数据加载逻辑。_DataLoaderIter
(或_SingleProcessDataLoaderIter
/_MultiProcessingDataLoaderIter
):实际的迭代器实现。Sampler
:控制数据采样的策略(如顺序采样、随机采样)。Dataset
:定义数据集的接口。
DataLoader 的核心实现
DataLoader 类
DataLoader
类是用户直接调用的接口,主要参数包括:
dataset
:继承自torch.utils.data.Dataset
的对象。batch_size
:批处理大小。shuffle
:是否打乱数据。sampler
:自定义采样器。num_workers
:加载数据的进程数。collate_fn
:批处理数据的函数。
初始化时,DataLoader
会根据参数选择单进程或多进程模式:
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None):# 参数校验与初始化逻辑...
迭代器实现
DataLoader
的迭代器分为两种:
- 单进程模式 (
_SingleProcessDataLoaderIter
):直接在主进程加载数据。 - 多进程模式 (
_MultiProcessingDataLoaderIter
):通过multiprocessing
启动子进程加载数据。
迭代器的核心方法是 __next__
,用于获取下一批数据:
def __next__(self):# 从数据队列中获取批数据data = self._get_data()# 处理数据加载异常if self._num_yielded >= len(self):raise StopIterationreturn data
多进程数据加载
多进程模式下,DataLoader
使用 torch.multiprocessing
创建多个工作进程(worker),每个 worker 通过 worker_init_fn
初始化,并通过队列(queue
)将数据返回给主进程。
关键流程:
- Worker 初始化:每个 worker 加载数据集并通过
collate_fn
处理数据。 - 数据预取:worker 提前加载数据到队列中,避免主进程等待。
- 异常处理:捕获 worker 中的异常并通过队列传递到主进程。
代码片段:
def _worker_loop(dataset, index_queue, data_queue, collate_fn, ...):# 初始化 workertorch.set_num_threads(1)while True:idx = index_queue.get()if idx is None:breaktry:data = collate_fn([dataset[i] for i in idx])data_queue.put((idx, data))except Exception:# 异常处理data_queue.put((idx, ExceptionWrapper(...)))
Sampler 和 BatchSampler
Sampler
控制数据索引的生成方式:
SequentialSampler
:顺序采样。RandomSampler
:随机采样。BatchSampler
:将采样器生成的索引分组为批次。
BatchSampler
的示例:
class BatchSampler(Sampler):def __init__(self, sampler, batch_size, drop_last):self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_lastdef __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if not self.drop_last and batch:yield batch
关键优化技术
- Pin Memory:将数据加载到固定的 GPU 内存中,加速数据传输。
- Prefetching:worker 提前加载下一批数据,减少主进程等待时间。
- Shared Memory:在多进程模式下,通过共享内存减少数据复制开销。
总结
DataLoader
的核心设计围绕高效数据加载展开,通过多进程、采样器和批处理机制实现高性能数据流水线。其源码逻辑清晰,模块化设计便于扩展,是 PyTorch 训练流程中的重要组件。