当前位置: 首页 > news >正文

【pytorch学习打卡挑战】day3 Pytorch的Dataset与DataLoader及DataLoader源码剖析

前言

本专题致力于学习Pytorch及其相关项目。
参照B站教程
【4、PyTorch的Dataset与DataLoader详细使用教程】
【5、深入剖析PyTorch DataLoader源码】

今日任务

第4、5个视频,主要围绕官方文档来介绍Dataset与DataLoader。
回顾时,建议直接看官方文档:
pytorch官方文档

内容总结

基础知识

PyTorch中的Dataset与DataLoader

PyTorch提供了DatasetDataLoader两个核心类来处理数据加载与批处理。Dataset负责存储样本及其对应的标签(处理单个样本),DataLoader则围绕Dataset进行迭代,提供批量加载、打乱数据和多进程加载等功能。

Dataset是一个抽象类,用户自定义数据集需继承此类并实现__len____getitem__方法。__len__返回数据集大小,__getitem__根据索引返回单个样本。

DataLoader接收一个Dataset对象,并封装了批量加载、多线程等复杂逻辑。常用参数包括batch_size(批大小)、shuffle(是否打乱数据)和num_workers(加载数据的进程数)。

自定义数据读取函数的方法

若要实现自定义数据读取逻辑,需继承Dataset类并重写关键方法。以下是实现步骤:

继承torch.utils.data.Dataset

创建一个新类并继承Dataset,初始化时加载数据路径或其他元信息。例如处理图像分类任务时,可保存图像路径和标签列表。

from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data_path, transform=None):self.data = [...]  # 加载数据路径或元信息self.labels = [...]  # 加载对应标签self.transform = transform  # 数据预处理(如标准化、增强)
实现__len__方法

返回数据集的总样本数,通常为数据列表的长度。

def __len__(self):return len(self.data)
实现__getitem__方法

根据索引返回单个样本和标签。在此方法中完成数据读取(如打开图像文件)和预处理。若提供transform,需在此处调用。

def __getitem__(self, idx):sample = self.data[idx]  # 读取数据(如用PIL打开图像)label = self.labels[idx]if self.transform:sample = self.transform(sample)  # 应用预处理return sample, label
使用DataLoader封装

实例化自定义Dataset后,通过DataLoader实现批量加载和多进程加速。

from torch.utils.data import DataLoaderdataset = CustomDataset(data_path="path/to/data", transform=my_transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

关键注意事项

  • 数据预处理:在__getitem__中完成数据读取和转换,避免在__init__中加载全部数据以减少内存占用。
  • 线程安全:若使用多进程(num_workers>0),确保数据读取逻辑是线程安全的。
  • 效率优化:对于大规模数据,建议使用延迟加载(lazy loading)或内存映射技术。

通过上述方法,用户可灵活适配各种数据格式(如文本、音频、视频)和存储结构(如文件夹、数据库)。

DataLoader 的基本结构

PyTorch 的 DataLoader 是用于批量加载数据的工具,支持多进程加载和自定义数据采样。其核心功能包括:

  • 数据集的迭代
  • 批处理(batching)
  • 数据打乱(shuffling)
  • 多进程加载(multiprocessing)

DataLoader 的源码位于 torch/utils/data/dataloader.py,主要由以下几个关键组件构成:

  • DataLoader 类:用户接口,封装了数据加载逻辑。
  • _DataLoaderIter(或 _SingleProcessDataLoaderIter/_MultiProcessingDataLoaderIter):实际的迭代器实现。
  • Sampler:控制数据采样的策略(如顺序采样、随机采样)。
  • Dataset:定义数据集的接口。

DataLoader 的核心实现

DataLoader 类

DataLoader 类是用户直接调用的接口,主要参数包括:

  • dataset:继承自 torch.utils.data.Dataset 的对象。
  • batch_size:批处理大小。
  • shuffle:是否打乱数据。
  • sampler:自定义采样器。
  • num_workers:加载数据的进程数。
  • collate_fn:批处理数据的函数。

初始化时,DataLoader 会根据参数选择单进程或多进程模式:

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None):# 参数校验与初始化逻辑...
迭代器实现

DataLoader 的迭代器分为两种:

  1. 单进程模式 (_SingleProcessDataLoaderIter):直接在主进程加载数据。
  2. 多进程模式 (_MultiProcessingDataLoaderIter):通过 multiprocessing 启动子进程加载数据。

迭代器的核心方法是 __next__,用于获取下一批数据:

def __next__(self):# 从数据队列中获取批数据data = self._get_data()# 处理数据加载异常if self._num_yielded >= len(self):raise StopIterationreturn data

多进程数据加载

多进程模式下,DataLoader 使用 torch.multiprocessing 创建多个工作进程(worker),每个 worker 通过 worker_init_fn 初始化,并通过队列(queue)将数据返回给主进程。

关键流程:

  1. Worker 初始化:每个 worker 加载数据集并通过 collate_fn 处理数据。
  2. 数据预取:worker 提前加载数据到队列中,避免主进程等待。
  3. 异常处理:捕获 worker 中的异常并通过队列传递到主进程。

代码片段:

def _worker_loop(dataset, index_queue, data_queue, collate_fn, ...):# 初始化 workertorch.set_num_threads(1)while True:idx = index_queue.get()if idx is None:breaktry:data = collate_fn([dataset[i] for i in idx])data_queue.put((idx, data))except Exception:# 异常处理data_queue.put((idx, ExceptionWrapper(...)))

Sampler 和 BatchSampler

Sampler 控制数据索引的生成方式:

  • SequentialSampler:顺序采样。
  • RandomSampler:随机采样。
  • BatchSampler:将采样器生成的索引分组为批次。

BatchSampler 的示例:

class BatchSampler(Sampler):def __init__(self, sampler, batch_size, drop_last):self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_lastdef __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if not self.drop_last and batch:yield batch

关键优化技术

  1. Pin Memory:将数据加载到固定的 GPU 内存中,加速数据传输。
  2. Prefetching:worker 提前加载下一批数据,减少主进程等待时间。
  3. Shared Memory:在多进程模式下,通过共享内存减少数据复制开销。

总结

DataLoader 的核心设计围绕高效数据加载展开,通过多进程、采样器和批处理机制实现高性能数据流水线。其源码逻辑清晰,模块化设计便于扩展,是 PyTorch 训练流程中的重要组件。

http://www.dtcms.com/a/491275.html

相关文章:

  • 解码Linux文件IO目录检索与文件属性
  • OpenLayers的过滤器 -- 章节三:相交过滤器详解
  • Micro850 控制器网络通信详解:从协议原理到实战配置
  • 六间房2025“逐光之战”海选启幕,歌舞闪耀悬念迭起
  • 把流量的pcap文件转成其他多种类型的数据(比如序列、图片、自然语言的嵌入),迁移其他领域的模型进行训练。
  • 惠济免费网站建设自己怎么建网站
  • 单机让多docker拥有多ip出口
  • 运城网站开发app阿里云最新消息
  • .NET 10深度解析:性能革新与开发生态的全新篇章
  • 国外住宅动态代理smartproxy,爬虫采集利器
  • 国外空间网站源码typecho wordpress比较
  • fineReport_数字转换英文函数
  • 公司网站二维码生成器网站界面设计套题
  • React API
  • 精彩网站制作横栏建设网站
  • 从《楞严经》与六祖惠能:论思想传承中的“不谋而合”
  • 引流软件有哪些网站优化关键词公司
  • 小程序获取
  • html表格,无序,有序,自定义,无语义,表单标签,特殊字符详解
  • 网站排名优化培训xx单位网站建设方案
  • 重庆 建网站做淘客网站企业备案
  • MySQL 8.0事务性数据字典全面解析
  • React高频面试题参考答案
  • 网页制作模板的网站做网站的财务需求
  • 建设银行e路护航官方网站登陆seo网站有优化培训班吗
  • 高可用巡检脚本实战:一键掌握服务、网络、VIP、资源状态
  • 2025全国仿真建模应用挑战赛选题建议与分析
  • 新网站域名备案流程小程序商店开发
  • 惠州网页建站模板网站开发都需要学什么
  • QT-day2,信号和槽