PyTorch DataLoader 高级用法
好的,这是一个关于 PyTorch DataLoader
中 sampler
和 collate_fn
等参数的非常好的问题。这些参数是 PyTorch 数据加载管道的核心,理解它们能让你高度自定义数据处理流程,以适应各种复杂的任务需求。
我将为你详细解释指定这些参数的方式以及它们之间的区别,内容将涵盖:
DataLoader
的核心工作流程sampler
和batch_sampler
:控制数据采样的顺序和方式- 默认行为 (
shuffle=True/False
) - 指定
sampler
:精细控制样本顺序 - 指定
batch_sampler
:精细控制批次构成 - 三者之间的关系和互斥性
- 默认行为 (
collate_fn
:自定义样本到批次的转换- 默认行为 (
default_collate
) - 指定自定义
collate_fn
- 默认行为 (
- 总结与最佳实践
1. DataLoader
的核心工作流程
在深入细节之前,我们先理解 DataLoader
在一个 epoch 中是如何工作的:
- 启动迭代: 当你写
for batch in data_loader:
时,DataLoader
的迭代器被创建。 - 获取索引:
DataLoader
首先向sampler
(或batch_sampler
) 请求样本的索引 (indices)。- 如果使用
sampler
,它会一个一个地返回样本索引。DataLoader
内部会根据batch_size
和drop_last
将这些索引组成一个批次的索引列表。 - 如果使用
batch_sampler
,它会一次性返回一个完整的、已经组织好的批次索引列表。
- 如果使用
- 获取数据:
DataLoader
使用上一步得到的批次索引列表,通过dataset[index]
的方式从你的Dataset
对象中获取一批数据样本。这时你会得到一个列表,列表中的每个元素都是Dataset
的__getitem__
方法返回的结果。例如[sample1, sample2, sample3, ...]
. - 整理批次:
DataLoader
将这个样本列表传递给collate_fn
函数。 - 返回批次:
collate_fn
函数将样本列表处理(例如,堆叠成 Tensor)并返回一个最终的批次(batch)。这个批次就是你在for
循环中接收到的batch
变量。
一个简化的工作流程图
现在,我们来详细看 sampler
和 collate_fn
。
2. sampler
和 batch_sampler
:控制数据采样
sampler
的核心职责是生成一系列索引,决定了从数据集中抽取样本的顺序。
2.1 默认行为 (shuffle
)
这是最简单、最常见的方式。你在创建 DataLoader
时,通过 shuffle
参数来控制。
import torch
from torch.utils.data import TensorDataset, DataLoader# 创建一个简单的数据集
data = torch.randn(10, 3) # 10个样本,每个样本3个特征
labels = torch.arange(10) # 标签为 0 到 9
dataset = TensorDataset(data, labels)# 方式一:顺序采样
# shuffle=False (默认)
loader_seq = DataLoader(dataset, batch_size=4, shuffle=False)
print("顺序采样 (shuffle=False):")
for _, batch_labels in loader_seq:print(batch_labels.tolist())
# 输出:
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]# 方式二:随机采样
# shuffle=True
loader_rand = DataLoader(dataset, batch_size=4, shuffle=True)
print("\n随机采样 (shuffle=True):")
for _, batch_labels in loader_rand:print(batch_labels.tolist())
# 输出 (每次可能不同):
# [3, 8, 1, 5]
# [0, 9, 2, 6]
# [7, 4]
工作原理:
shuffle=False
:DataLoader
内部会使用SequentialSampler
,它按照0, 1, 2, ...
的顺序生成索引。shuffle=True
:DataLoader
内部会使用RandomSampler
,它会在每个 epoch 开始时,将所有索引(0
到len(dataset)-1
)随机打乱,然后按打乱后的顺序生成索引。
区别:
- 这是最上层的抽象,简单直接。
- 你无法进行更复杂的采样控制,比如类别均衡采样。
2.2 指定 sampler
当你需要比简单的“顺序”或“随机”更复杂的采样策略时,就需要手动创建一个 sampler
对象并传递给 DataLoader
。
重要: 当你手动指定 sampler
时,必须将 shuffle
设置为 False
(或不设置,默认为 False
)。因为 sampler
已经定义了索引的生成顺序,shuffle=True
会与之冲突。
PyTorch 内置了一些有用的 sampler
:
SequentialSampler
: 按顺序采样,等同于shuffle=False
。RandomSampler
: 随机采样,等同于shuffle=True
。SubsetRandomSampler
: 在一个给定的索引子集内进行随机采样。常用于交叉验证。WeightedRandomSampler
: 根据每个样本的权重进行采样。常用于处理类别不均衡问题。
示例:使用 WeightedRandomSampler
进行类别均衡采样
假设我们有一个不均衡的数据集,类别 ‘0’ 的样本远多于类别 ‘1’。我们希望在训练时,每个批次中类别 ‘0’ 和 ‘1’ 的样本数量大致相等。
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSamplerclass ImbalancedDataset(Dataset):def __init__(self):# 90个类别0的样本, 10个类别1的样本self.data = torch.randn(100, 5)self.labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.data[idx], self.labels[idx]dataset = ImbalancedDataset()# 计算每个样本的权重
# 类别0的权重: 1 / 90
# 类别1的权重: 1 / 10
class_counts = [90.0, 10.0]
num_samples = sum(class_counts)
class_weights = [num_samples / class_count for class_count in class_counts]# 为数据集中的每个样本分配权重
sample_weights = [class_weights[label] for label in dataset.labels]
sample_weights = torch.DoubleTensor(sample_weights)# 创建 WeightedRandomSampler
# num_samples: 每个epoch采样的总数
# replacement=True: 允许重复采样(对于不均衡问题通常需要)
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)# 创建 DataLoader,注意 shuffle=False
# 因为 sampler 已经决定了采样顺序
loader_balanced = DataLoader(dataset, batch_size=10, sampler=sampler)print("使用 WeightedRandomSampler 进行均衡采样:")
for epoch in range(2):print(f"Epoch {epoch+1}")for _, batch_labels in loader_balanced:print(f" Batch labels: {batch_labels.tolist()}")print(f" Class counts: 0={torch.sum(batch_labels == 0).item()}, 1={torch.sum(batch_labels == 1).item()}")# 输出 (每次可能不同, 但类别1的比例会显著提高):
# Epoch 1
# Batch labels: [1, 1, 1, 0, 1, 0, 1, 0, 1, 0]
# Class counts: 0=4, 1=6
# ... (其他批次)
# Epoch 2
# ...
自定义 sampler
: 你还可以通过继承 torch.utils.data.Sampler
并实现 __iter__
和 __len__
方法来创建自己的采样器。
from torch.utils.data import Sampler
import numpy as npclass EvenOddSampler(Sampler):"""一个自定义的采样器,先采样所有偶数索引,再采样所有奇数索引"""def __init__(self, data_source):self.data_source = data_sourceself.even_indices = [i for i in range(len(data_source)) if i % 2 == 0]self.odd_indices = [i for i in range(len(data_source)) if i % 2 != 0]def __iter__(self):# 返回一个索引的迭代器return iter(self.even_indices + self.odd_indices)def __len__(self):return len(self.data_source)# 使用自定义sampler
dataset_simple = TensorDataset(torch.arange(10))
custom_sampler = EvenOddSampler(dataset_simple)
loader_custom = DataLoader(dataset_simple, batch_size=4, sampler=custom_sampler)print("\n使用自定义 EvenOddSampler:")
for batch in loader_custom:print(batch[0].tolist())
# 输出:
# [0, 2, 4, 6]
# [8, 1, 3, 5]
# [7, 9]
2.3 指定 batch_sampler
batch_sampler
是一个更底层的工具。它不像 sampler
那样一次返回一个索引,而是一次返回一个批次的索引列表。
重要: 当你手动指定 batch_sampler
时,以下参数将被忽略且必须不设置:batch_size
, shuffle
, sampler
, drop_last
。因为 batch_sampler
已经完全接管了批次的形成方式。
这在你需要对批次内的样本构成有特殊要求时非常有用。例如,在 NLP 中,为了减少 padding,你可能希望将长度相近的句子放在同一个批次里。
示例:使用 BatchSampler
BatchSampler
是一个包装器,它接收一个 sampler
和 batch_size
、drop_last
参数,然后生成批次索引。这实际上是 DataLoader
内部的默认行为。
from torch.utils.data import BatchSampler, SequentialSamplerdataset_simple = TensorDataset(torch.arange(20))# 创建一个顺序采样器
seq_sampler = SequentialSampler(dataset_simple)# 使用 BatchSampler 包装它
batch_sampler = BatchSampler(seq_sampler, batch_size=5, drop_last=False)# 创建 DataLoader,注意其他参数都不能设置
loader_batch_sampler = DataLoader(dataset_simple, batch_sampler=batch_sampler)print("\n使用 BatchSampler:")
for batch in loader_batch_sampler:print(batch[0].tolist())
# 输出:
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]
# [15, 16, 17, 18, 19]
自定义 batch_sampler
:这才是 batch_sampler
真正强大的地方。你可以继承 torch.utils.data.Sampler
(注意,BatchSampler
也继承自 Sampler
)并实现 __iter__
,让它 yield
一个个批次的索引列表。
class GroupLengthBatchSampler(Sampler):"""一个自定义的BatchSampler,尝试将长度相近的样本分到同一个批次。这是一个简化的实现。"""def __init__(self, data_source, batch_size):self.data_source = data_sourceself.batch_size = batch_size# 假设 data_source 有一个 'lengths' 属性# 按长度排序索引self.sorted_indices = np.argsort([len(x) for x in self.data_source.texts])def __iter__(self):# 将排序后的索引分块for i in range(0, len(self.sorted_indices), self.batch_size):yield self.sorted_indices[i : i + self.batch_size]def __len__(self):return (len(self.sorted_indices) + self.batch_size - 1) // self.batch_size# 假设有一个带文本的数据集
class TextDataset(Dataset):def __init__(self):self.texts = ["short", "a bit longer", "very very long sentence", "medium one","tiny", "another medium one", "this is also a very long sentence"]def __len__(self): return len(self.texts)def __getitem__(self, idx): return self.texts[idx]text_dataset = TextDataset()
my_batch_sampler = GroupLengthBatchSampler(text_dataset, batch_size=2)# collate_fn 在这里只是为了打印,后面会详细讲
loader_grouped = DataLoader(text_dataset, batch_sampler=my_batch_sampler, collate_fn=lambda x: x)print("\n使用自定义 GroupLengthBatchSampler:")
for batch in loader_grouped:print(f"Batch: {batch}, Lengths: {[len(s) for s in batch]}")# 输出 (按长度排序后的批次):
# Batch: ['short', 'tiny'], Lengths: [5, 4]
# Batch: ['medium one', 'a bit longer'], Lengths: [10, 12]
# Batch: ['another medium one', 'very very long sentence'], Lengths: [18, 25]
# Batch: ['this is also a very long sentence'], Lengths: [33]
2.4 三者关系总结
方式 | 作用 | 如何工作 | 互斥参数 | 适用场景 |
---|---|---|---|---|
shuffle=True/False | 控制是随机还是顺序采样 | 内部使用 RandomSampler 或 SequentialSampler | 无 | 最简单、最常见的场景。 |
sampler=... | 定义单个样本的抽取顺序 | 提供一个生成索引序列的迭代器 | shuffle | 需要复杂采样逻辑,如类别均衡、子集采样等。 |
batch_sampler=... | 定义批次索引列表的生成方式 | 提供一个生成索引列表的迭代器 | batch_size , shuffle , sampler , drop_last | 需要控制批次内部的构成,如按长度分组以减少padding。 |
3. collate_fn
:自定义样本到批次的转换
collate_fn
的职责是在 DataLoader
从 Dataset
获取到一个样本列表后,将这个列表整理(collate)成一个批次。
3.1 默认行为 (default_collate
)
如果你不指定 collate_fn
,DataLoader
会使用 torch.utils.data.default_collate
。它的行为是:
- 它会尝试将输入样本列表中的每个元素(通常是元组,如
(data, label)
)的对应部分堆叠(stack)起来。 - 它能处理 PyTorch Tensors, NumPy arrays, Python numbers 和 strings。
- 对于 Tensors,它会使用
torch.stack
在第0维(批次维)上进行堆叠。 - 这要求一个批次内的所有样本都有相同的形状。
默认行为失败的场景:
当一个批次内的样本形状不同时,default_collate
会失败。最常见的例子是 NLP 中的变长序列或计算机视觉中的不同尺寸图像。
# 失败的例子
dataset_variable_len = [torch.tensor([1,2,3]), torch.tensor([4,5])]
try:# 默认的 collate_fn 无法处理不同长度的 tensorloader_fail = DataLoader(dataset_variable_len, batch_size=2)for batch in loader_fail:pass
except RuntimeError as e:print(f"默认 collate_fn 失败: {e}")
# 输出: 默认 collate_fn 失败: stack expects each tensor to be equal size, but got [3] at entry 0 and [2] at entry 1
3.2 指定自定义 collate_fn
为了解决上述问题,你可以提供一个自定义的 collate_fn
函数。这个函数接收一个列表,列表中的每个元素都是 Dataset
的 __getitem__
的返回值。你需要在这个函数里实现将这个列表转换成一个批次的逻辑。
示例:为变长序列实现 padding
这是 collate_fn
最经典的应用。
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequenceclass VariableLengthDataset(Dataset):def __init__(self):self.data = [(torch.tensor([1, 2, 3]), 0),(torch.tensor([4, 5]), 1),(torch.tensor([6, 7, 8, 9]), 0),(torch.tensor([10]), 1)]def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx] # 返回 (sequence, label)def custom_collate_fn(batch):"""自定义的 collate_fn 函数,用于处理变长序列。:param batch: 一个列表,其中每个元素是 Dataset 的 __getitem__ 返回值。例如: [ (tensor([1,2,3]), 0), (tensor([4,5]), 1) ]"""# 1. 将数据和标签分离sequences = [item[0] for item in batch]labels = [item[1] for item in batch]# 2. 对序列进行 padding# pad_sequence 会自动将序列填充到该批次中最长序列的长度# batch_first=True 表示返回的 tensor 形状为 (batch_size, seq_len)padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)# 3. 将标签转换为 Tensorlabels = torch.LongTensor(labels)# 4. 返回处理好的批次return padded_sequences, labels# 创建 DataLoader 并指定自定义的 collate_fn
dataset = VariableLengthDataset()
loader_padded = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=custom_collate_fn)print("\n使用自定义 collate_fn 进行 padding:")
for seq_batch, label_batch in loader_padded:print("Padded Sequences Batch:")print(seq_batch)print("Labels Batch:")print(label_batch)print("-" * 20)# 可能的输出:
# 使用自定义 collate_fn 进行 padding:
# Padded Sequences Batch:
# tensor([[ 6, 7, 8, 9],
# [10, 0, 0, 0]])
# Labels Batch:
# tensor([0, 1])
# --------------------
# Padded Sequences Batch:
# tensor([[1, 2, 3],
# [4, 5, 0]])
# Labels Batch:
# tensor([0, 1])
# --------------------
collate_fn
的区别:
- 它不关心样本的抽取顺序,只关心拿到一批样本后如何组合。
- 它的功能与
sampler
是正交的、互补的。你可以同时使用自定义的sampler
和自定义的collate_fn
。
4. 总结与最佳实践
-
简单场景: 如果你只需要顺序或随机打乱数据,并且所有数据样本形状一致,那么直接使用
DataLoader
的shuffle=True/False
和batch_size
参数就足够了。 -
类别不均衡/特定顺序: 如果你需要解决类别不均衡问题(使用
WeightedRandomSampler
)或按特定规则(如先训练简单样本,后训练难样本)抽取数据,那么你需要自定义sampler
,并记得设置shuffle=False
。 -
优化批次内构成: 如果你想通过将相似长度/大小的样本组合在一起以优化计算效率(例如,减少 NLP 中的 padding 或 CV 中可变尺寸图像的处理开销),那么你需要自定义
batch_sampler
。这是最底层的控制,它会覆盖batch_size
,shuffle
,sampler
等参数。 -
处理可变数据: 如果你的数据样本形状不一(如变长文本、不同尺寸的图片),导致默认的堆叠操作失败,那么你需要自定义
collate_fn
。在这个函数里,你可以实现 padding、图像缩放等预处理,将一批异构的样本转换成一个规整的 Tensor 批次。
黄金组合: 在复杂的任务中,你经常会同时使用这些工具。例如,在 NLP 任务中,一个高效的 DataLoader
可能会:
- 使用一个自定义的
batch_sampler
,它首先根据句子长度对所有样本进行粗略分组,然后在每个组内进行随机采样,最后形成批次。这能保证批次内长度相近,同时保留一定的随机性。 - 使用一个自定义的
collate_fn
,它接收batch_sampler
给出的索引所对应的一批样本,然后对它们进行精确的 padding,并同时处理标签和其它元数据。
通过灵活组合 sampler
, batch_sampler
和 collate_fn
,你可以为几乎任何数据类型和训练策略构建出高效、定制化的数据加载管道。