当前位置: 首页 > news >正文

PyTorch DataLoader 高级用法

好的,这是一个关于 PyTorch DataLoadersamplercollate_fn 等参数的非常好的问题。这些参数是 PyTorch 数据加载管道的核心,理解它们能让你高度自定义数据处理流程,以适应各种复杂的任务需求。

我将为你详细解释指定这些参数的方式以及它们之间的区别,内容将涵盖:

  1. DataLoader 的核心工作流程
  2. samplerbatch_sampler:控制数据采样的顺序和方式
    • 默认行为 (shuffle=True/False)
    • 指定 sampler:精细控制样本顺序
    • 指定 batch_sampler:精细控制批次构成
    • 三者之间的关系和互斥性
  3. collate_fn:自定义样本到批次的转换
    • 默认行为 (default_collate)
    • 指定自定义 collate_fn
  4. 总结与最佳实践

1. DataLoader 的核心工作流程

在深入细节之前,我们先理解 DataLoader 在一个 epoch 中是如何工作的:

  1. 启动迭代: 当你写 for batch in data_loader: 时,DataLoader 的迭代器被创建。
  2. 获取索引: DataLoader 首先向 sampler (或 batch_sampler) 请求样本的索引 (indices)。
    • 如果使用 sampler,它会一个一个地返回样本索引。DataLoader 内部会根据 batch_sizedrop_last 将这些索引组成一个批次的索引列表。
    • 如果使用 batch_sampler,它会一次性返回一个完整的、已经组织好的批次索引列表。
  3. 获取数据: DataLoader 使用上一步得到的批次索引列表,通过 dataset[index] 的方式从你的 Dataset 对象中获取一批数据样本。这时你会得到一个列表,列表中的每个元素都是 Dataset__getitem__ 方法返回的结果。例如 [sample1, sample2, sample3, ...].
  4. 整理批次: DataLoader 将这个样本列表传递给 collate_fn 函数。
  5. 返回批次: collate_fn 函数将样本列表处理(例如,堆叠成 Tensor)并返回一个最终的批次(batch)。这个批次就是你在 for 循环中接收到的 batch 变量。

一个简化的工作流程图

现在,我们来详细看 samplercollate_fn


2. samplerbatch_sampler:控制数据采样

sampler 的核心职责是生成一系列索引,决定了从数据集中抽取样本的顺序。

2.1 默认行为 (shuffle)

这是最简单、最常见的方式。你在创建 DataLoader 时,通过 shuffle 参数来控制。

import torch
from torch.utils.data import TensorDataset, DataLoader# 创建一个简单的数据集
data = torch.randn(10, 3) # 10个样本,每个样本3个特征
labels = torch.arange(10) # 标签为 0 到 9
dataset = TensorDataset(data, labels)# 方式一:顺序采样
# shuffle=False (默认)
loader_seq = DataLoader(dataset, batch_size=4, shuffle=False)
print("顺序采样 (shuffle=False):")
for _, batch_labels in loader_seq:print(batch_labels.tolist())
# 输出:
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]# 方式二:随机采样
# shuffle=True
loader_rand = DataLoader(dataset, batch_size=4, shuffle=True)
print("\n随机采样 (shuffle=True):")
for _, batch_labels in loader_rand:print(batch_labels.tolist())
# 输出 (每次可能不同):
# [3, 8, 1, 5]
# [0, 9, 2, 6]
# [7, 4]

工作原理:

  • shuffle=False: DataLoader 内部会使用 SequentialSampler,它按照 0, 1, 2, ... 的顺序生成索引。
  • shuffle=True: DataLoader 内部会使用 RandomSampler,它会在每个 epoch 开始时,将所有索引(0len(dataset)-1)随机打乱,然后按打乱后的顺序生成索引。

区别:

  • 这是最上层的抽象,简单直接。
  • 你无法进行更复杂的采样控制,比如类别均衡采样。
2.2 指定 sampler

当你需要比简单的“顺序”或“随机”更复杂的采样策略时,就需要手动创建一个 sampler 对象并传递给 DataLoader

重要: 当你手动指定 sampler 时,必须将 shuffle 设置为 False (或不设置,默认为 False)。因为 sampler 已经定义了索引的生成顺序,shuffle=True 会与之冲突。

PyTorch 内置了一些有用的 sampler

  • SequentialSampler: 按顺序采样,等同于 shuffle=False
  • RandomSampler: 随机采样,等同于 shuffle=True
  • SubsetRandomSampler: 在一个给定的索引子集内进行随机采样。常用于交叉验证。
  • WeightedRandomSampler: 根据每个样本的权重进行采样。常用于处理类别不均衡问题。

示例:使用 WeightedRandomSampler 进行类别均衡采样

假设我们有一个不均衡的数据集,类别 ‘0’ 的样本远多于类别 ‘1’。我们希望在训练时,每个批次中类别 ‘0’ 和 ‘1’ 的样本数量大致相等。

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSamplerclass ImbalancedDataset(Dataset):def __init__(self):# 90个类别0的样本, 10个类别1的样本self.data = torch.randn(100, 5)self.labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.data[idx], self.labels[idx]dataset = ImbalancedDataset()# 计算每个样本的权重
# 类别0的权重: 1 / 90
# 类别1的权重: 1 / 10
class_counts = [90.0, 10.0]
num_samples = sum(class_counts)
class_weights = [num_samples / class_count for class_count in class_counts]# 为数据集中的每个样本分配权重
sample_weights = [class_weights[label] for label in dataset.labels]
sample_weights = torch.DoubleTensor(sample_weights)# 创建 WeightedRandomSampler
# num_samples: 每个epoch采样的总数
# replacement=True: 允许重复采样(对于不均衡问题通常需要)
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)# 创建 DataLoader,注意 shuffle=False
# 因为 sampler 已经决定了采样顺序
loader_balanced = DataLoader(dataset, batch_size=10, sampler=sampler)print("使用 WeightedRandomSampler 进行均衡采样:")
for epoch in range(2):print(f"Epoch {epoch+1}")for _, batch_labels in loader_balanced:print(f"  Batch labels: {batch_labels.tolist()}")print(f"  Class counts: 0={torch.sum(batch_labels == 0).item()}, 1={torch.sum(batch_labels == 1).item()}")# 输出 (每次可能不同, 但类别1的比例会显著提高):
# Epoch 1
#   Batch labels: [1, 1, 1, 0, 1, 0, 1, 0, 1, 0]
#   Class counts: 0=4, 1=6
#   ... (其他批次)
# Epoch 2
#   ...

自定义 sampler: 你还可以通过继承 torch.utils.data.Sampler 并实现 __iter____len__ 方法来创建自己的采样器。

from torch.utils.data import Sampler
import numpy as npclass EvenOddSampler(Sampler):"""一个自定义的采样器,先采样所有偶数索引,再采样所有奇数索引"""def __init__(self, data_source):self.data_source = data_sourceself.even_indices = [i for i in range(len(data_source)) if i % 2 == 0]self.odd_indices = [i for i in range(len(data_source)) if i % 2 != 0]def __iter__(self):# 返回一个索引的迭代器return iter(self.even_indices + self.odd_indices)def __len__(self):return len(self.data_source)# 使用自定义sampler
dataset_simple = TensorDataset(torch.arange(10))
custom_sampler = EvenOddSampler(dataset_simple)
loader_custom = DataLoader(dataset_simple, batch_size=4, sampler=custom_sampler)print("\n使用自定义 EvenOddSampler:")
for batch in loader_custom:print(batch[0].tolist())
# 输出:
# [0, 2, 4, 6]
# [8, 1, 3, 5]
# [7, 9]
2.3 指定 batch_sampler

batch_sampler 是一个更底层的工具。它不像 sampler 那样一次返回一个索引,而是一次返回一个批次的索引列表

重要: 当你手动指定 batch_sampler 时,以下参数将被忽略且必须不设置:batch_size, shuffle, sampler, drop_last。因为 batch_sampler 已经完全接管了批次的形成方式。

这在你需要对批次内的样本构成有特殊要求时非常有用。例如,在 NLP 中,为了减少 padding,你可能希望将长度相近的句子放在同一个批次里。

示例:使用 BatchSampler

BatchSampler 是一个包装器,它接收一个 samplerbatch_sizedrop_last 参数,然后生成批次索引。这实际上是 DataLoader 内部的默认行为。

from torch.utils.data import BatchSampler, SequentialSamplerdataset_simple = TensorDataset(torch.arange(20))# 创建一个顺序采样器
seq_sampler = SequentialSampler(dataset_simple)# 使用 BatchSampler 包装它
batch_sampler = BatchSampler(seq_sampler, batch_size=5, drop_last=False)# 创建 DataLoader,注意其他参数都不能设置
loader_batch_sampler = DataLoader(dataset_simple, batch_sampler=batch_sampler)print("\n使用 BatchSampler:")
for batch in loader_batch_sampler:print(batch[0].tolist())
# 输出:
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]
# [15, 16, 17, 18, 19]

自定义 batch_sampler:这才是 batch_sampler 真正强大的地方。你可以继承 torch.utils.data.Sampler(注意,BatchSampler 也继承自 Sampler)并实现 __iter__,让它 yield 一个个批次的索引列表。

class GroupLengthBatchSampler(Sampler):"""一个自定义的BatchSampler,尝试将长度相近的样本分到同一个批次。这是一个简化的实现。"""def __init__(self, data_source, batch_size):self.data_source = data_sourceself.batch_size = batch_size# 假设 data_source 有一个 'lengths' 属性# 按长度排序索引self.sorted_indices = np.argsort([len(x) for x in self.data_source.texts])def __iter__(self):# 将排序后的索引分块for i in range(0, len(self.sorted_indices), self.batch_size):yield self.sorted_indices[i : i + self.batch_size]def __len__(self):return (len(self.sorted_indices) + self.batch_size - 1) // self.batch_size# 假设有一个带文本的数据集
class TextDataset(Dataset):def __init__(self):self.texts = ["short", "a bit longer", "very very long sentence", "medium one","tiny", "another medium one", "this is also a very long sentence"]def __len__(self): return len(self.texts)def __getitem__(self, idx): return self.texts[idx]text_dataset = TextDataset()
my_batch_sampler = GroupLengthBatchSampler(text_dataset, batch_size=2)# collate_fn 在这里只是为了打印,后面会详细讲
loader_grouped = DataLoader(text_dataset, batch_sampler=my_batch_sampler, collate_fn=lambda x: x)print("\n使用自定义 GroupLengthBatchSampler:")
for batch in loader_grouped:print(f"Batch: {batch}, Lengths: {[len(s) for s in batch]}")# 输出 (按长度排序后的批次):
# Batch: ['short', 'tiny'], Lengths: [5, 4]
# Batch: ['medium one', 'a bit longer'], Lengths: [10, 12]
# Batch: ['another medium one', 'very very long sentence'], Lengths: [18, 25]
# Batch: ['this is also a very long sentence'], Lengths: [33]
2.4 三者关系总结
方式作用如何工作互斥参数适用场景
shuffle=True/False控制是随机还是顺序采样内部使用 RandomSamplerSequentialSampler最简单、最常见的场景。
sampler=...定义单个样本的抽取顺序提供一个生成索引序列的迭代器shuffle需要复杂采样逻辑,如类别均衡、子集采样等。
batch_sampler=...定义批次索引列表的生成方式提供一个生成索引列表的迭代器batch_size, shuffle, sampler, drop_last需要控制批次内部的构成,如按长度分组以减少padding。

3. collate_fn:自定义样本到批次的转换

collate_fn 的职责是在 DataLoaderDataset 获取到一个样本列表后,将这个列表整理(collate)成一个批次

3.1 默认行为 (default_collate)

如果你不指定 collate_fnDataLoader 会使用 torch.utils.data.default_collate。它的行为是:

  • 它会尝试将输入样本列表中的每个元素(通常是元组,如 (data, label))的对应部分堆叠(stack)起来。
  • 它能处理 PyTorch Tensors, NumPy arrays, Python numbers 和 strings。
  • 对于 Tensors,它会使用 torch.stack 在第0维(批次维)上进行堆叠。
  • 这要求一个批次内的所有样本都有相同的形状。

默认行为失败的场景:
当一个批次内的样本形状不同时,default_collate 会失败。最常见的例子是 NLP 中的变长序列或计算机视觉中的不同尺寸图像。

# 失败的例子
dataset_variable_len = [torch.tensor([1,2,3]), torch.tensor([4,5])]
try:# 默认的 collate_fn 无法处理不同长度的 tensorloader_fail = DataLoader(dataset_variable_len, batch_size=2)for batch in loader_fail:pass
except RuntimeError as e:print(f"默认 collate_fn 失败: {e}")
# 输出: 默认 collate_fn 失败: stack expects each tensor to be equal size, but got [3] at entry 0 and [2] at entry 1
3.2 指定自定义 collate_fn

为了解决上述问题,你可以提供一个自定义的 collate_fn 函数。这个函数接收一个列表,列表中的每个元素都是 Dataset__getitem__ 的返回值。你需要在这个函数里实现将这个列表转换成一个批次的逻辑。

示例:为变长序列实现 padding

这是 collate_fn 最经典的应用。

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequenceclass VariableLengthDataset(Dataset):def __init__(self):self.data = [(torch.tensor([1, 2, 3]), 0),(torch.tensor([4, 5]), 1),(torch.tensor([6, 7, 8, 9]), 0),(torch.tensor([10]), 1)]def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx] # 返回 (sequence, label)def custom_collate_fn(batch):"""自定义的 collate_fn 函数,用于处理变长序列。:param batch: 一个列表,其中每个元素是 Dataset 的 __getitem__ 返回值。例如: [ (tensor([1,2,3]), 0), (tensor([4,5]), 1) ]"""# 1. 将数据和标签分离sequences = [item[0] for item in batch]labels = [item[1] for item in batch]# 2. 对序列进行 padding# pad_sequence 会自动将序列填充到该批次中最长序列的长度# batch_first=True 表示返回的 tensor 形状为 (batch_size, seq_len)padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)# 3. 将标签转换为 Tensorlabels = torch.LongTensor(labels)# 4. 返回处理好的批次return padded_sequences, labels# 创建 DataLoader 并指定自定义的 collate_fn
dataset = VariableLengthDataset()
loader_padded = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=custom_collate_fn)print("\n使用自定义 collate_fn 进行 padding:")
for seq_batch, label_batch in loader_padded:print("Padded Sequences Batch:")print(seq_batch)print("Labels Batch:")print(label_batch)print("-" * 20)# 可能的输出:
# 使用自定义 collate_fn 进行 padding:
# Padded Sequences Batch:
# tensor([[ 6,  7,  8,  9],
#         [10,  0,  0,  0]])
# Labels Batch:
# tensor([0, 1])
# --------------------
# Padded Sequences Batch:
# tensor([[1, 2, 3],
#         [4, 5, 0]])
# Labels Batch:
# tensor([0, 1])
# --------------------

collate_fn 的区别:

  • 它不关心样本的抽取顺序,只关心拿到一批样本后如何组合
  • 它的功能与 sampler 是正交的、互补的。你可以同时使用自定义的 sampler 和自定义的 collate_fn

4. 总结与最佳实践

  1. 简单场景: 如果你只需要顺序或随机打乱数据,并且所有数据样本形状一致,那么直接使用 DataLoadershuffle=True/Falsebatch_size 参数就足够了。

  2. 类别不均衡/特定顺序: 如果你需要解决类别不均衡问题(使用 WeightedRandomSampler)或按特定规则(如先训练简单样本,后训练难样本)抽取数据,那么你需要自定义 sampler,并记得设置 shuffle=False

  3. 优化批次内构成: 如果你想通过将相似长度/大小的样本组合在一起以优化计算效率(例如,减少 NLP 中的 padding 或 CV 中可变尺寸图像的处理开销),那么你需要自定义 batch_sampler。这是最底层的控制,它会覆盖 batch_size, shuffle, sampler 等参数。

  4. 处理可变数据: 如果你的数据样本形状不一(如变长文本、不同尺寸的图片),导致默认的堆叠操作失败,那么你需要自定义 collate_fn。在这个函数里,你可以实现 padding、图像缩放等预处理,将一批异构的样本转换成一个规整的 Tensor 批次。

黄金组合: 在复杂的任务中,你经常会同时使用这些工具。例如,在 NLP 任务中,一个高效的 DataLoader 可能会:

  • 使用一个自定义的 batch_sampler,它首先根据句子长度对所有样本进行粗略分组,然后在每个组内进行随机采样,最后形成批次。这能保证批次内长度相近,同时保留一定的随机性。
  • 使用一个自定义的 collate_fn,它接收 batch_sampler 给出的索引所对应的一批样本,然后对它们进行精确的 padding,并同时处理标签和其它元数据。

通过灵活组合 sampler, batch_samplercollate_fn,你可以为几乎任何数据类型和训练策略构建出高效、定制化的数据加载管道。

http://www.dtcms.com/a/490547.html

相关文章:

  • 怎么做一个网站app吗金华网站建设价格
  • 芷江建设局网站石家庄网站建设公司黄页
  • Excel表----VLOOKUP函数实现两表的姓名、身份证号码、银行卡号核对
  • XMLHttpRequest.responseType:前端获取后端数据的一把“格式钥匙”
  • office便捷办公06:根据相似度去掉excel中的重复行
  • Vue+mockjs+Axios 案例实践
  • http的发展历程
  • Python中使用HTTP 206状态码实现大文件下载的完整指南
  • AngularJS下 $http 上传文件
  • 如何弄死一个网站锡林郭勒盟建设工程造价管理网站
  • 【Node.js】为什么擅长处理 I/O 密集型应用?
  • 基于SpringBoot的无人机飞行管理系统
  • STM32的HardFault错误处理技巧
  • Tekever-固定翼无人机系统:模块化垂直起降、远程海上无人机、战术 ISR 无人机
  • Kafka Queue: 如何严格控制消息数量
  • 大兴建设网站wordpress 托管主机
  • 国外html响应式网站网站开发高级证
  • 苍穹外卖--04--Redis 缓存菜品信息、购物车
  • 大淘客网站如何做seowordpress o2o主题
  • 机器学习催化剂设计专题学习
  • (六)机器学习之图卷积网络
  • 告别刀耕火种:用 Makefile 自动化 C 语言项目编译
  • 【安卓开发】【Android】做一个简单的钢琴模拟器
  • C#控制反转
  • 【Java 开发日记】什么是线程池?它的工作原理?
  • 黄页网站数据来源wordpress 最新漏洞
  • 如何评价3D高斯泼溅(3DGS)技术为数字孪生与实时渲染带来的突破性进展?
  • 技术解析:如何将NX(UG)模型高效转换为3DXML格式
  • 阿里云智能建站网络类黄页
  • SAP MIR7 模拟过账没有这个按钮