当前位置: 首页 > news >正文

Pytorch的Dataloader使用详解

PyTorch 的 DataLoader 是数据加载的核心组件,它能高效地批量加载数据并进行预处理。

Pytorch DataLoader基础概念

DataLoader基础概念
DataLoader是PyTorch基础概念
DataLoader是PyTorch中用于加载数据的工具,它可以:批量加载数据(batch loading)打乱数据(shuffling)并行加载数据(多线程)
自定义数据加载方式Dataloader的基本使用from torch.utils.data import Dataset, DataLoader

自定义数据集类

class MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __getitem__(self, index):return self.data[index], self.labels[index]def __len__(self):return len(self.data)

创建数据集实例

dataset = MyDataset(data, labels)

创建DataLoader

dataloader = DataLoader(dataset=dataset,      # 数据集batch_size=32,        # 批次大小shuffle=True,         # 是否打乱数据num_workers=4,        # 多进程加载数据的线程数drop_last=False       # 当样本数不能被batch_size整除时,是否丢弃最后一个不完整的batch
)
# 使用DataLoader迭代数据
for batch_data, batch_labels in dataloader:# 训练或推理代码pass

DataLoader重要参数详解

  1. dataset: 要加载的数据集,必须是Dataset类的实例 batch_size: 每个批次的样本数
  2. shuffle:是否在每个epoch重新打乱数据
  3. sampler:自定义从数据集中抽取样本的策略,如果指定了sampler,则shuffle必须为False
  4. num_workers:使用多少个子进程加载数据,0表示在主进程中加载。
  5. collate_fn:将一批数据整合成一个批次的函数,特别使用于处理不同长度的序列数据
  6. Pin_memory:如果为True,数据加载器会将张量复制到CUDA固定内存中,加速CPU到GPU的数据传输
  7. drop_last: 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的批次。
  8. timeout:收集一个批次的超时值
  9. worker_init_fn:每个worker初始化时被调用的函数
  10. weight_sampler:参数决定是都使用加权采样器来平衡类别分布
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class
这段代码决定了如何创建数据加载器,根据infinite_data_loader参数选择不同的加载器类型:
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class

代码解析

这段代码基于infinite_data_loader参数创建不同类型的数据加载器:
当infinite_data_loader为True时:
创建InfiniteDataLoader实例
自定义的无限循环数据加载器,会持续提供数据而不会在一个epoch结束时停止
当infinite_data_loader为False时:
创建标准的PyTorch DataLoader实例
这是普通的数据加载器,一个epoch结束后会停止

共同参数:

dataset=data:要加载的数据集
batch_size=batch_size:每批数据的大小
shuffle=shuffle:是否打乱数据(之前代码中已设置)
num_workers=num_workers:用于并行加载数据的线程数
sampler=sampler:用于采样的策略(之前代码中已设置,可能是加权采样器)
**kwargs:其他可能的参数,如pin_memory、drop_last等

返回值:

data_loader:创建好的数据加载器
n_class = len(data.classes):数据集中的类别数量
InfiniteDataLoader的作用
在您的代码中定义了两种InfiniteDataLoader实现:一种作为DataLoader的子类,另一种是完全自定义的类。它们的共同目的是:
持续提供数据:当一个epoch结束后,自动重新开始,不会引发StopIteration异常
支持长时间训练:在需要长时间训练的场景中特别有用,如半监督学习或者领域适应
避免手动重置:不需要在每个epoch结束后手动重置数据加载器

使用场景

无限数据加载器特别适用于:
持续训练:模型需要无限期地训练,如自监督学习或强化学习
不均匀更新:源域和目标域数据需要不同频率的更新
流式训练:数据以流的形式到达,不需要明确的epoch边界
基于迭代而非epoch的训练:训练基于迭代次数而非数据epoch
最后的返回值n_class提供了数据集的类别数量,这对模型构建和评估都很重要,比如设置分类层的输出维度或计算平均类别准确率。
高级用法

1.自定义collate_fn处理变长序列

def collate_fn(batch):# 排序批次数据,按序列长度降序batch.sort(key=lambda x: len(x[0]), reverse=True)# 分离数据和标签sequences, labels = zip(*batch)# 计算每个序列的长度lengths = [len(seq) for seq in sequences]# 填充序列到相同长度padded_seqs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)return padded_seqs, torch.tensor(labels), lengths

使用自定义的collate_fn

dataloader = DataLoader(dataset=text_dataset,batch_size=16,shuffle=True,collate_fn=collate_fn
)

2.使用Sampler进行不均衡数据采样
from torch.utils.data import WeightedRandomSampler

假设我们有类别不平衡问题,计算采样权重

class_count = [100, 1000, 500]  # 每个类别的样本数量
weights = 1.0 / torch.tensor(class_count, dtype=torch.float)
sample_weights = weights[target_list]  # target_list是每个样本的类别索引

创建WeightedRandomSampler

sampler = WeightedRandomSampler(weights=sample_weights,num_samples=len(sample_weights),replacement=True
)

使用sampler

dataloader = DataLoader(dataset=dataset,batch_size=32,sampler=sampler,  # 使用sampler时,shuffle必须为Falsenum_workers=4
)

相关文章:

  • 【USRP】在linux下安装python API调用
  • Oracle 中的虚拟列Virtual Columns和PostgreSQL Generated Columns生成列
  • 一分钟了解大语言模型(LLMs)
  • 基于ssm+mysql的高校设备管理系统(含LW+PPT+源码+系统演示视频+安装说明)
  • 音频分类的学习
  • De-biased Attention Supervision for Text Classifcation with Causality
  • 学习51单片机01(安装开发环境)
  • 基于Matlab的非线性Newmark法用于计算结构动力响应
  • STM32 之网口资源
  • 当 DeepSeek 遇见区块链:一场颠覆式的应用革命
  • 学习黑客蓝牙技术详解
  • SAP Fiori Elements Object Page
  • rocketmq 拉取消息
  • AI智能体 | 使用Coze一键制作“假如书籍会说话”视频,18个作品狂吸17.6万粉,读书博主新标杆!(附保姆级教程)
  • 输入一个正整数,将其各位数字倒序输出(如输入123,输出321)
  • 【行为型之模板方法模式】游戏开发实战——Unity标准化流程与可扩展架构的核心实现
  • Prometheus 的介绍与部署(入门)
  • 第二章 变量和运算符
  • git push 报错:send-pack: unexpected disconnect while reading sideband packet
  • c#队列及其操作
  • 透视社会组织创新实践中的花开岭现象:与乡村发展的融合共进
  • 商务部新闻发言人就暂停17家美国实体不可靠实体清单措施答记者问
  • 脑血管支架:救命神器还是定时炸弹?听听医生的大实话
  • 警方通报:某博主遭勒索后自杀系自导自演,已立案调查
  • 广西壮族自治区党委政法委副书记李文博接受审查调查
  • 中拉论坛第四届部长级会议将举行,外交部介绍情况