PyTorch 深度学习基础:Dataset 与 DataLoader 详解
文章目录
- 一、为什么需要 Dataset 和 DataLoader?
- 二、Dataset:数据集的抽象类
- 自定义 Dataset 示例
- 三、DataLoader:数据批处理与加载工具
- 四、Dataset + DataLoader 完整示例
- 五、DataLoader 的常用参数
- 六、Dataset 与 DataLoader 的关系图
- 七、总结
一、为什么需要 Dataset 和 DataLoader?
在深度学习中通常需要完成以下任务:
- 从文件中读取数据(图片、文本、CSV等)
- 进行预处理或数据增强
- 每次训练读取一个批次(batch)数据
- 打乱(shuffle)或并行加载数据以加快训练
如果手动写这些逻辑,代码会变得冗长且不易维护。
在使用 PyTorch 进行深度学习训练时,数据加载是不可或缺的一步。
PyTorch 为此提供了两个抽象接口来标准化数据加载流程:
| 组件 | 作用 |
|---|---|
Dataset | 定义数据集的内容与读取方式 |
DataLoader | 负责批量加载、打乱、并行读取数据 |
PyTorch 提供了两个核心组件来处理数据集:
torch.utils.data.Datasettorch.utils.data.DataLoader
这两者一起构成了 PyTorch 的数据加载体系。
二、Dataset:数据集的抽象类
Dataset 是 PyTorch 中所有数据集的基类。
它定义了两个必须实现的方法:
__len__() # 返回数据集的样本总数
__getitem__(index) # 根据索引返回一个样本(数据 + 标签)
自定义 Dataset 示例
假设我们有一个简单的 CSV 文件,存储了输入特征和标签:
x1, x2, y
1.0, 2.0, 0
2.0, 3.0, 1
3.0, 4.0, 0
我们可以这样定义自定义数据集:
import torch
from torch.utils.data import Dataset
import pandas as pdclass MyDataset(Dataset):def __init__(self, csv_file):self.data = pd.read_csv(csv_file)def __len__(self):return len(self.data)def __getitem__(self, idx):x = self.data.iloc[idx, :-1].valuesy = self.data.iloc[idx, -1]return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)
使用方法:
dataset = MyDataset("data.csv")
print(len(dataset)) # 数据样本数量
print(dataset[0]) # 第一个样本 (x, y)
💡 提示:
PyTorch 已经为常用数据集(如 MNIST、CIFAR10、ImageNet 等)提供了内置实现,
也可以直接使用,例如:
from torchvision import datasets, transformstransform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
三、DataLoader:数据批处理与加载工具
DataLoader 是一个高效的数据加载器,用来按批次读取 Dataset 中的数据。
它可以自动:
- 按批加载(batch)
- 打乱数据(shuffle)
- 并行加载数据(num_workers)
基本用法
from torch.utils.data import DataLoadertrain_loader = DataLoader(dataset, batch_size=4, shuffle=True)
然后可以像这样迭代读取数据:
for batch_idx, (x, y) in enumerate(train_loader):print(batch_idx, x, y)
关于Python enumerate()函数用法,可参考Python 基础详解:enumerate() 函数-CSDN博客
输出示例:
0 tensor([[1., 2.], [3., 4.], [2., 3.], [4., 5.]]) tensor([0, 0, 1, 1])
1 tensor([[5., 6.], [7., 8.], [6., 7.], [8., 9.]]) tensor([1, 0, 1, 0])
四、Dataset + DataLoader 完整示例
以下是一个完整的可运行例子:
import torch
from torch.utils.data import Dataset, DataLoader# 1. 自定义数据集
class MyDataset(Dataset):def __init__(self):self.data = torch.arange(10)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.data[idx] ** 2 # 返回 x 和 x²# 2. 创建数据集
dataset = MyDataset()# 3. 创建 DataLoader
loader = DataLoader(dataset, batch_size=3, shuffle=True)# 4. 迭代读取
for x, y in loader:print(f"x: {x}, y: {y}")
输出示例:
x: tensor([7, 6, 2]), y: tensor([49, 36, 4])
x: tensor([4, 5, 1]), y: tensor([16, 25, 1])
x: tensor([8, 3, 0]), y: tensor([64, 9, 0])
x: tensor([9]), y: tensor([81])
可以看到:
DataLoader自动把样本打乱- 每次取 3 个数据(batch size=3)
- 返回的是
(x_batch, y_batch)
五、DataLoader 的常用参数
| 参数 | 作用 | 默认值 |
|---|---|---|
dataset | 数据集对象 | 必须 |
batch_size | 每个批次的样本数 | 1 |
shuffle | 是否打乱顺序 | False |
num_workers | 并行加载数据的线程数 | 0(Windows 推荐 0) |
drop_last | 是否丢弃最后一个不完整批次 | False |
pin_memory | 是否将张量拷贝到 CUDA 固定内存中(提高 GPU 速度) | False |
示例:
train_loader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=2,drop_last=True,pin_memory=True
)
六、Dataset 与 DataLoader 的关系图

Dataset 负责“是什么数据”,DataLoader 负责“怎么取数据”。
七、总结
| 组件 | 功能 | 是否必须实现 |
|---|---|---|
Dataset.__len__() | 返回样本数量 | ✅ 是 |
Dataset.__getitem__() | 返回单个样本 | ✅ 是 |
DataLoader | 负责批次加载、打乱、并行等 | ✅ 推荐使用 |
