DAY 38 Dataset和Dataloader类 - 2025.10. 2
Dataset和Dataloader类
知识点回顾:
Dataset
类的__getitem__
和__len__
方法(本质是python的特殊方法)Dataloader
类minist
手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
笔记:
一、Dataset 类:数据加载的 “基础模板”
Dataset
是 PyTorch 中抽象的数据集基类,作用是 “将原始数据封装成可索引的样本集合”(每个样本包含 “数据 + 标签”)。它强制要求子类实现两个Python 特殊方法,否则无法正常使用。
- 核心方法 1:
__len__()
—— 告诉数据集 “有多大”
- 作用:返回数据集的总样本数量,让外界知道 “数据集的规模”(比如训练集有 60000 张图)。
- 触发时机:当调用
len(dataset)
时自动执行(Python 特殊方法的特性)。 - 必须实现:若不实现,调用
len()
会报错。
- 核心方法 2:
__getitem__(idx)
—— 告诉数据集 “如何取样本”
- 作用:根据输入的索引
idx
(整数),返回对应位置的单个样本(通常是(数据, 标签
)的元组)。 - 触发时机:当调用
dataset[idx]
时自动执行(比如dataset[0]
取第 1 个样本)。 - 必须实现:这是
Dataset
的核心,没有它就无法通过索引获取数据,后续Dataloader
也无法批量加载。
- 直观示例:自定义一个简单 Dataset
用 “模拟的数字数据” 自定义 Dataset,理解两个方法的实现逻辑:
from torch.utils.data import Dataset # 导入基类# 自定义Dataset,继承自Dataset基类
class MySimpleDataset(Dataset):def __init__(self, data_list, label_list):# 初始化:保存原始数据和标签(构造函数,用户传入数据)self.data = data_list # 假设data_list是[1.2, 3.4, 5.6, 7.8](模拟4个样本)self.labels = label_list # 假设label_list是[0, 1, 0, 1](对应4个样本的标签)def __len__(self):# 返回总样本数:数据和标签的长度一致,取其一即可return len(self.data) # 这里返回4def __getitem__(self, idx):# 根据索引idx取单个样本single_data = self.data[idx] # 取第idx个数据(如idx=0时取1.2)single_label = self.labels[idx]# 取第idx个标签(如idx=0时取0)return single_data, single_label # 返回(数据,标签)元组# 测试自定义Dataset
data = [1.2, 3.4, 5.6, 7.8]
labels = [0, 1, 0, 1]
my_dataset = MySimpleDataset(data, labels)# 调用__len__():输出4
print(f"数据集总样本数:{len(my_dataset)}")
# 调用__getitem__(2):输出(5.6, 0)(第3个样本)
print(f"第3个样本(idx=2):{my_dataset[2]}")
二、Dataloader 类:数据加载的 “增强工具”
Dataset
解决了 “单样本索引访问”,但实际训练中需要批量加载、打乱数据、多线程加速等功能 ——Dataloader
就是干这个的,它是Dataset
的 “上层封装”,负责 “高效地为模型喂数据”。
1. 核心作用
- 批量加载:将
Dataset
的单样本按batch_size
打包成 “批量数据”(比如一次加载 32 个样本)。 - 打乱数据:训练时打乱样本顺序,避免模型 “记忆顺序”,提升泛化能力。
- 多线程加速:用
num_workers
开启多线程,并行读取数据,减少模型等待时间。 - 自动拼接:将多个单样本的 Tensor 自动拼接成批量 Tensor(如 32 个 28×28 的图像→32×28×28)。
- 核心参数(必掌握)
参数名 | 作用说明 |
---|---|
dataset | 关联的Dataset 对象(必须传,Dataloader 从这个Dataset 中取数据) |
batch_size | 每次加载的样本数量(如32、64,根据显卡显存调整) |
shuffle | 是否打乱数据(True :训练时用;False :测试时用,保证结果可复现) |
num_workers | 用于加载数据的线程数(0 :默认,主线程加载;建议设为4 、8 ,根据 CPU 调整) |
drop_last | 若数据集总样本数不能被batch_size 整除,是否丢弃最后一个不完整的批次(通常False ) |
3. 实操示例:用 Dataloader 加载自定义 Dataset
承接上面的MySimpleDataset
,用Dataloader
批量加载:
from torch.utils.data import DataLoader# 1. 先有Dataset对象(之前定义的my_dataset)
# 2. 构建Dataloader
my_dataloader = DataLoader(dataset=my_dataset, # 关联的Datasetbatch_size=2, # 每次加载2个样本shuffle=True, # 打乱顺序(模拟训练场景)num_workers=0 # 主线程加载(简单场景用0,避免多线程报错)
)# 3. 迭代Dataloader获取批量数据(模型训练时就是这么循环取数据的)
print("批量加载的数据:")
for batch_idx, (batch_data, batch_labels) in enumerate(my_dataloader):print(f"第{batch_idx+1}批:")print(f" 批量数据:{batch_data}") # 2个样本的数组合并(如tensor([3.4, 7.8]))print(f" 批量标签:{batch_labels}")# 2个样本的标签合并(如tensor([1, 1]))
输出说明:因为shuffle=True
,每次运行的批次顺序可能不同,但每个批次都有 2 个样本(符合batch_size=2
)。
三、MNIST 手写数据集:经典的 “实战案例”
MNIST 是手写数字分类数据集,是计算机视觉入门的 “hello world”,完美适配Dataset
和Dataloader
的使用场景。
1. 数据集核心信息
特性 | 说明 |
---|---|
内容 | 0-9 的手写数字图像(共 10 个类别) |
样本数量 | 训练集:60000 张;测试集:10000 张 |
图像尺寸 | 28×28 像素,灰度图(1 个通道,像素值 0-255,0 表示黑色,255 表示白色) |
标签 | 0-9 的整数(对应图像中的数字) |
加载方式 | 通过torchvision.datasets.MNIST 直接加 |
2. 实操:用 Dataset+Dataloader 加载 MNIST
结合前两个知识点,完整流程如下(包含图像显示,直观理解数据):
import torchvision.datasets as datasets # 包含MNIST Dataset
import torchvision.transforms as transforms # 图像预处理
import matplotlib.pyplot as plt # 显示图像# 1. 定义图像预处理(MNIST原始是PIL图像,需转为Tensor)
transform = transforms.Compose([transforms.ToTensor(), # 转为Tensor:① PIL→Tensor;② 像素值从0-255归一化到0-1;③ 维度从H×W→C×H×W(1×28×28)
])# 2. 加载MNIST训练集(Dataset对象)
mnist_dataset = datasets.MNIST(root="./data", # 数据集保存路径(不存在则自动创建)train=True, # True:加载训练集;False:加载测试集download=True, # True:自动下载数据集(首次运行需联网,约11MB)transform=transform # 应用预处理
)# 3. 用Dataloader批量加载
mnist_dataloader = DataLoader(dataset=mnist_dataset,batch_size=32, # 一次加载32张图shuffle=True, # 训练时打乱num_workers=0
)# 4. 验证知识点:调用Dataset的__len__和__getitem__
print(f"MNIST训练集总样本数:{len(mnist_dataset)}") # 输出60000
single_img, single_label = mnist_dataset[100] # 取第101个样本
print(f"单张图像形状(C×H×W):{single_img.shape}") # 输出torch.Size([1, 28, 28])
print(f"单张图像标签:{single_label}") # 输出该图像对应的数字(如5)# 5. 显示单张MNIST图像(需调整维度:C×H×W→H×W)
plt.figure(figsize=(3, 3))
# permute(1,2,0):交换维度(1=H,2=W,0=C),灰度图可省略最后一维
plt.imshow(single_img.permute(1, 2, 0).squeeze(), cmap="gray") # squeeze()去掉通道维度
plt.title(f"MNIST Digit: {single_label}")
plt.axis("off")
plt.show()# 6. 迭代Dataloader获取批量数据(模拟模型训练)
for batch_imgs, batch_labels in mnist_dataloader:print(f"\n批量图像形状(batch_size×C×H×W):{batch_imgs.shape}") # 输出torch.Size([32, 1, 28, 28])print(f"批量标签形状(batch_size):{batch_labels.shape}") # 输出torch.Size([32])break # 只看第一批,避免循环太多
四、知识点联动总结
三个知识点的逻辑关系可以用一句话概括:
MNIST
是 “数据来源”,Dataset
将MNIST
封装成 “可索引的样本集合”,Dataloader
再将Dataset
封装成 “适合模型训练的批量数据流”。
具体流程:
原始 MNIST 文件 → MNIST(Dataset)
实现__len__
/__getitem__
)→ Dataloader
(批量 / 打乱 / 多线程)→ 模型训练(每次接收一个批次的数据)。
作业
# 1. 导入必要的库
import torchvision.datasets as datasets # 包含CIFAR数据集的Dataset类
import torchvision.transforms as transforms # 图像预处理工具
import matplotlib.pyplot as plt # 显示图片
import numpy as np # 辅助图像格式转换# 2. 定义图像预处理(关键:将CIFAR的原始格式转为可显示的图像对象)
# transforms.Compose:组合多个预处理操作
transform = transforms.Compose([transforms.ToTensor(), # 把PIL图像(0-255)转为Tensor(0-1,维度:C×H×W,即通道×高度×宽度)# 可选:如需还原为0-255范围,可加transforms.Lambda(lambda x: x*255)
])# 3. 加载CIFAR-10数据集(使用Dataset类)
# root:数据集保存路径(不存在则自动创建)
# train=True:加载训练集;train=False:加载测试集
# download=True:自动下载数据集(首次运行需联网,约160MB)
# transform:应用上述预处理
cifar10_dataset = datasets.CIFAR10(root='./data', # 数据集存在当前目录的data文件夹下train=True,download=True,transform=transform
)# 4. 查看Dataset核心方法(呼应知识点回顾)
print(f"数据集总长度(__len__方法):{len(cifar10_dataset)}") # 输出50000(CIFAR-10训练集数量)# 5. 获取单张图片和对应的标签(使用__getitem__方法)
# 索引可任意选(0~49999,训练集;0~9999,测试集)
index = 123 # 选第124张图片(索引从0开始)
img_tensor, label = cifar10_dataset[index] # __getitem__返回(图像Tensor,标签)# 查看获取的内容
print(f"\n图像Tensor形状(C×H×W):{img_tensor.shape}") # 输出torch.Size([3, 32, 32])(3通道、32×32)
print(f"图片标签(数字):{label}") # 输出0~9的数字(对应CIFAR-10的10个类别)# 6. 定义CIFAR-10的类别名称(将数字标签映射为具体类别)
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'
]
print(f"图片类别(名称):{cifar10_classes[label]}") # 输出标签对应的类别名称# 7. 显示图片(需将Tensor格式转为matplotlib支持的格式)
# 步骤:① 调整维度(C×H×W → H×W×C);② 转为numpy数组;③ 若Tensor是0-1范围,转回0-255(可选,更清晰)
img_show = img_tensor.permute(1, 2, 0).numpy() # permute:交换维度(1=H,2=W,0=C)
img_show = img_show * 255 # 从0-1范围转回0-255(因ToTensor默认归一化到0-1)
img_show = img_show.astype(np.uint8) # 转为8位整数(图像像素的标准格式)# 用matplotlib显示
plt.figure(figsize=(4, 4)) # 设置图片大小
plt.imshow(img_show) # 显示图像
plt.title(f"CIFAR-10 Image: {cifar10_classes[label]} (Index: {index})") # 标题:类别+索引
plt.axis('off') # 隐藏坐标轴
plt.show() # 弹出显示窗口
Duplicate key in file WindowsPath(‘d:/Anaconda/envs/DL39/lib/site-packages/matplotlib/mpl-data/matplotlibrc’), line 263 (‘font.family: simhei’)
数据集总长度(__len__方法):50000
图像Tensor形状(C×H×W):torch.Size([3, 32, 32])
图片标签(数字):2
图片类别(名称):bird
@浙大疏锦行