当前位置: 首页 > news >正文

DAY 38 Dataset和Dataloader类 - 2025.10. 2

Dataset和Dataloader类

知识点回顾:

  1. Dataset类的__getitem____len__方法(本质是python的特殊方法)
  2. Dataloader
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

笔记:

一、Dataset 类:数据加载的 “基础模板”

Dataset是 PyTorch 中抽象的数据集基类,作用是 “将原始数据封装成可索引的样本集合”(每个样本包含 “数据 + 标签”)。它强制要求子类实现两个Python 特殊方法,否则无法正常使用。

  1. 核心方法 1:__len__() —— 告诉数据集 “有多大”
  • 作用:返回数据集的总样本数量,让外界知道 “数据集的规模”(比如训练集有 60000 张图)。
  • 触发时机:当调用len(dataset)时自动执行(Python 特殊方法的特性)。
  • 必须实现:若不实现,调用len()会报错。
  1. 核心方法 2:__getitem__(idx) —— 告诉数据集 “如何取样本”
  • 作用:根据输入的索引idx(整数),返回对应位置的单个样本(通常是(数据, 标签)的元组)。
  • 触发时机:当调用dataset[idx]时自动执行(比如dataset[0]取第 1 个样本)。
  • 必须实现:这是Dataset的核心,没有它就无法通过索引获取数据,后续Dataloader也无法批量加载。
  1. 直观示例:自定义一个简单 Dataset
    用 “模拟的数字数据” 自定义 Dataset,理解两个方法的实现逻辑:
from torch.utils.data import Dataset  # 导入基类# 自定义Dataset,继承自Dataset基类
class MySimpleDataset(Dataset):def __init__(self, data_list, label_list):# 初始化:保存原始数据和标签(构造函数,用户传入数据)self.data = data_list    # 假设data_list是[1.2, 3.4, 5.6, 7.8](模拟4个样本)self.labels = label_list # 假设label_list是[0, 1, 0, 1](对应4个样本的标签)def __len__(self):# 返回总样本数:数据和标签的长度一致,取其一即可return len(self.data)  # 这里返回4def __getitem__(self, idx):# 根据索引idx取单个样本single_data = self.data[idx]   # 取第idx个数据(如idx=0时取1.2)single_label = self.labels[idx]# 取第idx个标签(如idx=0时取0)return single_data, single_label  # 返回(数据,标签)元组# 测试自定义Dataset
data = [1.2, 3.4, 5.6, 7.8]
labels = [0, 1, 0, 1]
my_dataset = MySimpleDataset(data, labels)# 调用__len__():输出4
print(f"数据集总样本数:{len(my_dataset)}")
# 调用__getitem__(2):输出(5.6, 0)(第3个样本)
print(f"第3个样本(idx=2):{my_dataset[2]}")

二、Dataloader 类:数据加载的 “增强工具”

Dataset解决了 “单样本索引访问”,但实际训练中需要批量加载、打乱数据、多线程加速等功能 ——Dataloader就是干这个的,它是Dataset的 “上层封装”,负责 “高效地为模型喂数据”。
1. 核心作用

  • 批量加载:将Dataset的单样本按batch_size打包成 “批量数据”(比如一次加载 32 个样本)。
  • 打乱数据:训练时打乱样本顺序,避免模型 “记忆顺序”,提升泛化能力。
  • 多线程加速:用num_workers开启多线程,并行读取数据,减少模型等待时间。
  • 自动拼接:将多个单样本的 Tensor 自动拼接成批量 Tensor(如 32 个 28×28 的图像→32×28×28)。
  1. 核心参数(必掌握)
参数名作用说明
dataset关联的Dataset对象(必须传,Dataloader从这个Dataset中取数据)
batch_size每次加载的样本数量(如32、64,根据显卡显存调整)
shuffle是否打乱数据(True:训练时用;False:测试时用,保证结果可复现)
num_workers用于加载数据的线程数(0:默认,主线程加载;建议设为48,根据 CPU 调整)
drop_last若数据集总样本数不能被batch_size整除,是否丢弃最后一个不完整的批次(通常False

3. 实操示例:用 Dataloader 加载自定义 Dataset
承接上面的MySimpleDataset,用Dataloader批量加载:

from torch.utils.data import DataLoader# 1. 先有Dataset对象(之前定义的my_dataset)
# 2. 构建Dataloader
my_dataloader = DataLoader(dataset=my_dataset,    # 关联的Datasetbatch_size=2,          # 每次加载2个样本shuffle=True,          # 打乱顺序(模拟训练场景)num_workers=0          # 主线程加载(简单场景用0,避免多线程报错)
)# 3. 迭代Dataloader获取批量数据(模型训练时就是这么循环取数据的)
print("批量加载的数据:")
for batch_idx, (batch_data, batch_labels) in enumerate(my_dataloader):print(f"第{batch_idx+1}批:")print(f"  批量数据:{batch_data}")  # 2个样本的数组合并(如tensor([3.4, 7.8]))print(f"  批量标签:{batch_labels}")# 2个样本的标签合并(如tensor([1, 1]))

输出说明:因为shuffle=True,每次运行的批次顺序可能不同,但每个批次都有 2 个样本(符合batch_size=2)。

三、MNIST 手写数据集:经典的 “实战案例”

MNIST 是手写数字分类数据集,是计算机视觉入门的 “hello world”,完美适配DatasetDataloader的使用场景。

1. 数据集核心信息

特性说明
内容0-9 的手写数字图像(共 10 个类别)
样本数量训练集:60000 张;测试集:10000 张
图像尺寸28×28 像素,灰度图(1 个通道,像素值 0-255,0 表示黑色,255 表示白色)
标签0-9 的整数(对应图像中的数字)
加载方式通过torchvision.datasets.MNIST直接加

2. 实操:用 Dataset+Dataloader 加载 MNIST
结合前两个知识点,完整流程如下(包含图像显示,直观理解数据):

import torchvision.datasets as datasets  # 包含MNIST Dataset
import torchvision.transforms as transforms  # 图像预处理
import matplotlib.pyplot as plt  # 显示图像# 1. 定义图像预处理(MNIST原始是PIL图像,需转为Tensor)
transform = transforms.Compose([transforms.ToTensor(),  # 转为Tensor:① PIL→Tensor;② 像素值从0-255归一化到0-1;③ 维度从H×W→C×H×W(1×28×28)
])# 2. 加载MNIST训练集(Dataset对象)
mnist_dataset = datasets.MNIST(root="./data",          # 数据集保存路径(不存在则自动创建)train=True,            # True:加载训练集;False:加载测试集download=True,         # True:自动下载数据集(首次运行需联网,约11MB)transform=transform    # 应用预处理
)# 3. 用Dataloader批量加载
mnist_dataloader = DataLoader(dataset=mnist_dataset,batch_size=32,         # 一次加载32张图shuffle=True,          # 训练时打乱num_workers=0
)# 4. 验证知识点:调用Dataset的__len__和__getitem__
print(f"MNIST训练集总样本数:{len(mnist_dataset)}")  # 输出60000
single_img, single_label = mnist_dataset[100]  # 取第101个样本
print(f"单张图像形状(C×H×W):{single_img.shape}")  # 输出torch.Size([1, 28, 28])
print(f"单张图像标签:{single_label}")  # 输出该图像对应的数字(如5)# 5. 显示单张MNIST图像(需调整维度:C×H×W→H×W)
plt.figure(figsize=(3, 3))
# permute(1,2,0):交换维度(1=H,2=W,0=C),灰度图可省略最后一维
plt.imshow(single_img.permute(1, 2, 0).squeeze(), cmap="gray")  # squeeze()去掉通道维度
plt.title(f"MNIST Digit: {single_label}")
plt.axis("off")
plt.show()# 6. 迭代Dataloader获取批量数据(模拟模型训练)
for batch_imgs, batch_labels in mnist_dataloader:print(f"\n批量图像形状(batch_size×C×H×W):{batch_imgs.shape}")  # 输出torch.Size([32, 1, 28, 28])print(f"批量标签形状(batch_size):{batch_labels.shape}")        # 输出torch.Size([32])break  # 只看第一批,避免循环太多

四、知识点联动总结

三个知识点的逻辑关系可以用一句话概括:
MNIST是 “数据来源”,DatasetMNIST封装成 “可索引的样本集合”,Dataloader再将Dataset封装成 “适合模型训练的批量数据流”。

具体流程
原始 MNIST 文件 → MNIST(Dataset)实现__len__/__getitem__)→ Dataloader(批量 / 打乱 / 多线程)→ 模型训练(每次接收一个批次的数据)。

作业

# 1. 导入必要的库
import torchvision.datasets as datasets  # 包含CIFAR数据集的Dataset类
import torchvision.transforms as transforms  # 图像预处理工具
import matplotlib.pyplot as plt  # 显示图片
import numpy as np  # 辅助图像格式转换# 2. 定义图像预处理(关键:将CIFAR的原始格式转为可显示的图像对象)
# transforms.Compose:组合多个预处理操作
transform = transforms.Compose([transforms.ToTensor(),  # 把PIL图像(0-255)转为Tensor(0-1,维度:C×H×W,即通道×高度×宽度)# 可选:如需还原为0-255范围,可加transforms.Lambda(lambda x: x*255)
])# 3. 加载CIFAR-10数据集(使用Dataset类)
# root:数据集保存路径(不存在则自动创建)
# train=True:加载训练集;train=False:加载测试集
# download=True:自动下载数据集(首次运行需联网,约160MB)
# transform:应用上述预处理
cifar10_dataset = datasets.CIFAR10(root='./data',  # 数据集存在当前目录的data文件夹下train=True,download=True,transform=transform
)# 4. 查看Dataset核心方法(呼应知识点回顾)
print(f"数据集总长度(__len__方法):{len(cifar10_dataset)}")  # 输出50000(CIFAR-10训练集数量)# 5. 获取单张图片和对应的标签(使用__getitem__方法)
# 索引可任意选(0~49999,训练集;0~9999,测试集)
index = 123  # 选第124张图片(索引从0开始)
img_tensor, label = cifar10_dataset[index]  # __getitem__返回(图像Tensor,标签)# 查看获取的内容
print(f"\n图像Tensor形状(C×H×W):{img_tensor.shape}")  # 输出torch.Size([3, 32, 32])(3通道、32×32)
print(f"图片标签(数字):{label}")  # 输出0~9的数字(对应CIFAR-10的10个类别)# 6. 定义CIFAR-10的类别名称(将数字标签映射为具体类别)
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'
]
print(f"图片类别(名称):{cifar10_classes[label]}")  # 输出标签对应的类别名称# 7. 显示图片(需将Tensor格式转为matplotlib支持的格式)
# 步骤:① 调整维度(C×H×W → H×W×C);② 转为numpy数组;③ 若Tensor是0-1范围,转回0-255(可选,更清晰)
img_show = img_tensor.permute(1, 2, 0).numpy()  # permute:交换维度(1=H,2=W,0=C)
img_show = img_show * 255  # 从0-1范围转回0-255(因ToTensor默认归一化到0-1)
img_show = img_show.astype(np.uint8)  # 转为8位整数(图像像素的标准格式)# 用matplotlib显示
plt.figure(figsize=(4, 4))  # 设置图片大小
plt.imshow(img_show)  # 显示图像
plt.title(f"CIFAR-10 Image: {cifar10_classes[label]} (Index: {index})")  # 标题:类别+索引
plt.axis('off')  # 隐藏坐标轴
plt.show()  # 弹出显示窗口

Duplicate key in file WindowsPath(‘d:/Anaconda/envs/DL39/lib/site-packages/matplotlib/mpl-data/matplotlibrc’), line 263 (‘font.family: simhei’)
数据集总长度(__len__方法):50000

图像Tensor形状(C×H×W):torch.Size([3, 32, 32])
图片标签(数字):2
图片类别(名称):bird
在这里插入图片描述

@浙大疏锦行

http://www.dtcms.com/a/435246.html

相关文章:

  • Privacy Eraser(隐私保护软件)多语便携版
  • C4D R20新增功能概述及体积对象SDF类型深度解析
  • 上海做网站公司推荐简单网上书店网站建设php
  • HarmonyOS应用开发深度解析:ArkTS语法精要与UI组件实践
  • 北京示范校建设网站wordpress快速发布
  • 常用网站布局土巴兔这种网站怎么做
  • toLua[四] Examples 03_CallLuaFunction分析
  • 建设景区网站推文企业网站排名怎么优化
  • 汽车信息安全测试与ISO/SAE 21434标准
  • Hadoop HA 集群安装配置
  • 10.2总结
  • 旅游网站建设最重要的流程如何制作公众号教程
  • 淄博建设局网站秀堂h5官网
  • 【动态规划DP:纸币硬币专题】P2834 纸币问题 3
  • springbatch使用记录
  • 平面设计师网站都有哪些网站突然被降权怎么办
  • 前向传播与反向传播(附视频链接)
  • 广州建设工程造价管理站橙色网站欣赏
  • ipv6之6to4配置案例
  • 太仓有专门做网站的地方吗沧州企业网站专业定制
  • gRPC从0到1系列【14】
  • JVM的内存分配策略有哪些?
  • 卡特兰数【模板】(四个公式模板)
  • Process Monitor 学习笔记(5.5):保存/打开追踪记录——复盘、复现与分享的正确姿势
  • 【机器学习宝藏】深入解析经典人脸识别数据集:Olivetti Faces
  • 【C++】深入理解红黑树:概念、性质和实现
  • 制作卖东西网站玩具网站 下载
  • 网站建设培训课程wordpress描述插件
  • php网站超市源码下载十大永久免费crm
  • 网站色彩代码carousel wordpress