python打卡day38
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式,均继承自torch.utils.data
- DataLoader类:决定数据如何加载(批量大小batch_size和是否打乱数据顺序shuffle=True/False)
- Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理(数据路径和预处理transform)
torch.utils.data.Dataset是一个抽象基类,所有数据集都需要继承Dataset并定义两个核心方法:
- __len__():返回数据集的样本总数
- __getitem__(idx):根据索引idx返回对应样本的数据和标签
__getitem__和__len__ 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为,举个例子:
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data) # 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2]) # 输出:30
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj)) # 输出:5
DataLoader类就更好理解了,使用DataLoader类的正确流程是先通过Dataset类定义数据的读取方式和预处理,再通过DataLoader设定批次大小等参数进行加载,以一个自定义数据集举个例子
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data_path, transform=None):self.data = [...] # 加载数据列表(如文件路径列表)self.transform = transform # 预处理操作def __len__(self):return len(self.data)def __getitem__(self, idx):# 读取单个样本(如从文件路径加载图像)sample = self.load_sample(self.data[idx]) if self.transform is not None:sample = self.transform(sample) # 应用预处理return sample, label # 返回样本和标签# 先创建Dataset实例
dataset = MyDataset(data_path="./data", transform=my_transform) # 假设前面定义了预处理操作transform# 再创建DataLoader实例
dataloader = DataLoader(dataset,batch_size=32, # 批次大小shuffle=True, # 打乱数据顺序num_workers=4 # 使用4个线程加载数据
)
为了引入这些概念,我们现在接触一个新的而且非常经典的数据集:MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练
1、用到的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
torchvision
├── datasets # 视觉数据集(如 MNIST、CIFAR)
├── transforms # 视觉数据预处理(如裁剪、翻转、归一化)
├── models # 预训练模型(如 ResNet、YOLO)
├── utils # 视觉工具函数(如目标检测后处理)
└── io # 图像/视频 IO 操作
2、定义预处理操作transform
这里用 torchvision 的 transforms 模块,提供了一系列常用的图像预处理操作
# 数据预处理,该写法非常类似于管道pipeline
# 先归一化,再标准化
transform = transforms.Compose([ # compose用于将多个数据预处理操作按顺序组合成一个整体,参数是一个列表,每个操作是一个元素transforms.ToTensor(), # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # 标准化,MNIST数据集的均值和标准差,这个值很出名,所以直接使用# 参数格式是元组 (mean_channel1, mean_channel2, ...),由于MNIST是单通道(灰度图),这里只有一个值
])
3、创建dataset实例
torchvision 的 datasets 模块已经预定义了许多常见的数据集,实例化一个数据类就是创建dataset对象了
# 加载MNIST数据集,如果没有会自动下载,pytorch的思路是,数据在加载阶段就预处理结束
# 训练集
train_dataset = datasets.MNIST(root='./data', # 数据存储路径train=True,download=True, # 如果目录下数据不存在则自动下载transform=transform # 应用预处理
)# 测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
3、创建dataloader实例
# 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)
过程就是定义预处理transform ➡ 实例化一个数据集类(创建dataset实例)➡ 创建数据加载器(创建dataloader实例)➡ 后续操作
作业:了解下cifar数据集,尝试获取其中一张图片
cifar-10的图片就是32*32的彩色图,那就存在RGB三个通道上不同的灰度图,分别标准化和反标准化
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms# 设置中文字体
plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei']
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号# 定义CIFAR-10的均值和标准差
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)# 定义预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(cifar_mean, cifar_std)
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)# CIFAR-10的类别标签
classes = ('飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车')# 随机选择一个样本
index = np.random.randint(0, len(train_dataset))
image, label = train_dataset[index]# 反标准化操作 (针对3通道图像)
image = image.clone() # 避免修改原始数据
for i in range(3): # 对RGB三个通道分别反标准化image[i] = image[i] * cifar_std[i] + cifar_mean[i]# 转换为numpy并调整维度 (PyTorch: [C,H,W] → Matplotlib: [H,W,C])
image = np.transpose(image.numpy(), (1, 2, 0))# 显示图像
plt.figure(figsize=(5, 5))
plt.imshow(image)
plt.title(f'随机抽取的样本 - 标签: {classes[label]}')
plt.axis('off')
plt.show()
最后输出图片很模糊,可能因为数据集本身分辨率就不高,plot参数设置了但是中文还是没显示出来,很奇怪搞不懂
@浙大疏锦行