当前位置: 首页 > news >正文

Day 38: Dataset类和DataLoader类

核心概念

在处理大规模数据集时,显存往往无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch提供了两个关键类来解决这个问题:

  1. DataLoader类:决定数据如何加载
  2. Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理

实战演练:MNIST数据集

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)

2. 数据预处理

# 数据预处理管道
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的标准化参数
])

3. 加载MNIST数据集

# 加载训练集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)# 加载测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

🔧 Dataset类详解

Dataset类的核心方法

PyTorch的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

  • __len__():返回数据集的样本总数
  • __getitem__(idx):根据索引idx返回对应样本的数据和标签

魔术方法示例

# __getitem__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30
# __len__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 使用len()函数获取元素数量,这会自动调用__len__方法
my_list_obj = MyList()
print(len(my_list_obj))  # 输出:5

查看单个样本

# 获取一个样本
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
image, label = train_dataset[sample_idx]
print(f"Label: {label}")# 可视化图像
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')plt.show()imshow(image)

DataLoader类详解

DataLoader负责将Dataset中的数据按批次加载,并提供多种数据加载策略:

# 创建训练数据加载器
train_loader = DataLoader(train_dataset,batch_size=64,    # 每个批次64张图片shuffle=True      # 随机打乱数据
)# 创建测试数据加载器
test_loader = DataLoader(test_dataset,batch_size=1000   # 每个批次1000张图片# shuffle=False   # 测试时不需要打乱数据
)

Dataset vs DataLoader 对比

维度DatasetDataLoader
核心职责定义"数据是什么"和"如何获取单个样本"定义"如何批量加载数据"和"加载策略"
核心方法__getitem____len__无自定义方法,通过参数控制
预处理位置__getitem__中通过transform执行无预处理逻辑
并行处理无(仅单样本处理)支持多进程加载
典型参数roottransformbatch_sizeshufflenum_workers

总结

Dataset类的职责

  • 数据内容定义:数据存储路径、读取方式
  • 预处理逻辑:图像变换、数据增强等
  • 返回格式:如(image_tensor, label)

DataLoader类的职责

  • 批量处理:控制batch_size
  • 数据打乱:shuffle参数
  • 并行加载:num_workers参数
  • 内存管理:防止一次性加载过多数据

实用技巧

  1. batch_size选择:通常选择2的幂次方(32、64、128等),这与GPU计算效率相关
  2. 数据预处理时机:在Dataset的__getitem__方法中进行,而不是DataLoader中
  3. 内存优化:DataLoader的num_workers参数可以开启多进程加载,提高效率

@浙大疏锦行

http://www.dtcms.com/a/327420.html

相关文章:

  • 三点估算法(Three-Point Estimation)
  • OpenHarmony介绍
  • 知识篇 | Oracle Active Data Guard(ADG)同步机制再学习
  • TCP服务器网络编程设计流程详解
  • 车规级霍尔开关芯片SC25891 | 为汽车安全带扣筑起高可靠性安全防线
  • FileLink:为企业跨网文件传输筑牢安全与效率基石
  • Go 语言中的结构体、切片与映射:构建高效数据模型的基石
  • apache+虚拟主机
  • windows git安装步骤
  • 深入剖析 React 合成事件:透过 onClick 看本质
  • Flutter UI Kits by Olayemi Garuba:免费开源的高质量UI组件库
  • C++中template、 implicit 、explicit关键字详解
  • Kimi K2 架构深度解析:万亿MoE模型的效率革命与智能体突破
  • Linux随记(二十二)
  • Notta:高效智能的音频转文字工具
  • 视频抽取关键帧算法
  • MR一体机(VST)预算思路
  • Linux的pthread怎么实现的?(包括到汇编层的实现)
  • AWT 事件监听中的适配器模式:从原理到实战的完整指南
  • Photoshop软件打开WebP文件格的操作教程
  • leecode2439 最小化数组中的最大值
  • 大数据中的数据压缩原理
  • 【解决apisix问题】
  • 快速了解词向量模型
  • RIOT、RT-Thread 和 FreeRTOS 是三种主流的实时操作系统
  • SpringMVC的原理及执行流程?
  • Bugku-CTF-web-留言板1
  • Linux网络--2.2、TCP接口
  • PMBT2907A,215 Nxp安世半导体 双极性晶体管 开关电源管理芯片
  • 蚁剑--安装、使用