当前位置: 首页 > news >正文

PyTorch 深度学习基础:Dataset 与 DataLoader 详解

文章目录

    • 一、为什么需要 Dataset 和 DataLoader?
    • 二、Dataset:数据集的抽象类
      • 自定义 Dataset 示例
    • 三、DataLoader:数据批处理与加载工具
    • 四、Dataset + DataLoader 完整示例
    • 五、DataLoader 的常用参数
    • 六、Dataset 与 DataLoader 的关系图
    • 七、总结


一、为什么需要 Dataset 和 DataLoader?

在深度学习中通常需要完成以下任务:

  1. 从文件中读取数据(图片、文本、CSV等)
  2. 进行预处理或数据增强
  3. 每次训练读取一个批次(batch)数据
  4. 打乱(shuffle)或并行加载数据以加快训练

如果手动写这些逻辑,代码会变得冗长且不易维护。

在使用 PyTorch 进行深度学习训练时,数据加载是不可或缺的一步。
PyTorch 为此提供了两个抽象接口来标准化数据加载流程

组件作用
Dataset定义数据集的内容与读取方式
DataLoader负责批量加载、打乱、并行读取数据

PyTorch 提供了两个核心组件来处理数据集:

  • torch.utils.data.Dataset
  • torch.utils.data.DataLoader

这两者一起构成了 PyTorch 的数据加载体系


二、Dataset:数据集的抽象类

Dataset 是 PyTorch 中所有数据集的基类。
它定义了两个必须实现的方法:

__len__()   # 返回数据集的样本总数
__getitem__(index)  # 根据索引返回一个样本(数据 + 标签)

自定义 Dataset 示例

假设我们有一个简单的 CSV 文件,存储了输入特征和标签:

x1, x2, y
1.0, 2.0, 0
2.0, 3.0, 1
3.0, 4.0, 0

我们可以这样定义自定义数据集:

import torch
from torch.utils.data import Dataset
import pandas as pdclass MyDataset(Dataset):def __init__(self, csv_file):self.data = pd.read_csv(csv_file)def __len__(self):return len(self.data)def __getitem__(self, idx):x = self.data.iloc[idx, :-1].valuesy = self.data.iloc[idx, -1]return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

使用方法:

dataset = MyDataset("data.csv")
print(len(dataset))        # 数据样本数量
print(dataset[0])          # 第一个样本 (x, y)

💡 提示:
PyTorch 已经为常用数据集(如 MNIST、CIFAR10、ImageNet 等)提供了内置实现,
也可以直接使用,例如:

from torchvision import datasets, transformstransform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

三、DataLoader:数据批处理与加载工具

DataLoader 是一个高效的数据加载器,用来按批次读取 Dataset 中的数据。
它可以自动:

  • 按批加载(batch)
  • 打乱数据(shuffle)
  • 并行加载数据(num_workers)

基本用法

from torch.utils.data import DataLoadertrain_loader = DataLoader(dataset, batch_size=4, shuffle=True)

然后可以像这样迭代读取数据:

for batch_idx, (x, y) in enumerate(train_loader):print(batch_idx, x, y)

关于Python enumerate()函数用法,可参考Python 基础详解:enumerate() 函数-CSDN博客

输出示例:

0 tensor([[1., 2.], [3., 4.], [2., 3.], [4., 5.]]) tensor([0, 0, 1, 1])
1 tensor([[5., 6.], [7., 8.], [6., 7.], [8., 9.]]) tensor([1, 0, 1, 0])

四、Dataset + DataLoader 完整示例

以下是一个完整的可运行例子:

import torch
from torch.utils.data import Dataset, DataLoader# 1. 自定义数据集
class MyDataset(Dataset):def __init__(self):self.data = torch.arange(10)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.data[idx] ** 2  # 返回 x 和 x²# 2. 创建数据集
dataset = MyDataset()# 3. 创建 DataLoader
loader = DataLoader(dataset, batch_size=3, shuffle=True)# 4. 迭代读取
for x, y in loader:print(f"x: {x}, y: {y}")

输出示例:

x: tensor([7, 6, 2]), y: tensor([49, 36,  4])
x: tensor([4, 5, 1]), y: tensor([16, 25,  1])
x: tensor([8, 3, 0]), y: tensor([64,  9,  0])
x: tensor([9]), y: tensor([81])

可以看到:

  • DataLoader 自动把样本打乱
  • 每次取 3 个数据(batch size=3)
  • 返回的是 (x_batch, y_batch)

五、DataLoader 的常用参数

参数作用默认值
dataset数据集对象必须
batch_size每个批次的样本数1
shuffle是否打乱顺序False
num_workers并行加载数据的线程数0
(Windows 推荐 0)
drop_last是否丢弃最后一个不完整批次False
pin_memory是否将张量拷贝到 CUDA 固定内存中(提高 GPU 速度)False

示例:

train_loader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=2,drop_last=True,pin_memory=True
)

六、Dataset 与 DataLoader 的关系图

Dataset 负责“是什么数据”,DataLoader 负责“怎么取数据”。


七、总结

组件功能是否必须实现
Dataset.__len__()返回样本数量✅ 是
Dataset.__getitem__()返回单个样本✅ 是
DataLoader负责批次加载、打乱、并行等✅ 推荐使用
http://www.dtcms.com/a/529325.html

相关文章:

  • 2.4寸SPI串口ILI9341芯片彩色LCD驱动
  • 绍兴企业做网站浙江建设信息港电工证查询
  • 【系统分析师】高分论文:论需求分析及其应用(ERP 财务管控项目)
  • 数据结构(9)
  • 怎么做点播网站唐山企业做网站
  • 网站建设迅雷wordpress 简洁文章主题
  • 成都网站建设好多钱中英版网站怎么做
  • wait和notify机制详解
  • 网站开发文档需求撰写word营销型网站建站系统
  • wordpress order插件seo实训报告
  • 南宁建设厅网站是什么品牌网络市场环境调研报告
  • 做外贸需要做网站吗电子商务网站建设读书笔记
  • Linux17 进程间的通信 消息队列
  • 从WSL安装到初始化buildozer全过程~
  • 点击网站排名西南网架公司
  • 专做宠物的网站注册一个5000万空壳公司要多少钱
  • 长春火车站进站需要核酸检测吗豆瓣 wordpress
  • 【Java 序列化 (Serialization)】
  • STM32H743-ARM例程30-Modbus
  • ps网站导航怎么做wordpress 主题详解
  • 网站建设全网推广小程序网站制作app排行榜前十名
  • 正规网站建设多少费用深圳品牌设计公司哪家好
  • Product Hunt 每日热榜 | 2025-10-25
  • Java实用工具库深度解析:从生产力到艺术性
  • 全网营销网站建设特点南山出名的互联网公司
  • 计算机组成原理C,存储器容量计算地址线和数据线
  • 连云港建设局官方网站模板大全免费
  • 建设项目经济评价网站青岛公司网站建设价格
  • 重庆网站seo营销模板做网站怎么挣钱
  • 软件设计师知识点总结:软件工程