当前位置: 首页 > news >正文

python打卡day38

Dataset和DataLoader

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式,均继承自torch.utils.data

  • DataLoader类:决定数据如何加载(批量大小batch_size和是否打乱数据顺序shuffle=True/False)
  • Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理(数据路径和预处理transform)

torch.utils.data.Dataset是一个抽象基类,所有数据集都需要继承Dataset并定义两个核心方法

  1. __len__():返回数据集的样本总数
  2. __getitem__(idx):根据索引idx返回对应样本的数据和标签

__getitem__和__len__ 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为,举个例子:

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data)    # 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

DataLoader类就更好理解了,使用DataLoader类的正确流程是先通过Dataset类定义数据的读取方式和预处理,再通过DataLoader设定批次大小等参数进行加载,以一个自定义数据集举个例子

from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data_path, transform=None):self.data = [...]  # 加载数据列表(如文件路径列表)self.transform = transform  # 预处理操作def __len__(self):return len(self.data)def __getitem__(self, idx):# 读取单个样本(如从文件路径加载图像)sample = self.load_sample(self.data[idx])  if self.transform is not None:sample = self.transform(sample)  # 应用预处理return sample, label  # 返回样本和标签# 先创建Dataset实例
dataset = MyDataset(data_path="./data", transform=my_transform) # 假设前面定义了预处理操作transform# 再创建DataLoader实例
dataloader = DataLoader(dataset,batch_size=32,    # 批次大小shuffle=True,     # 打乱数据顺序num_workers=4     # 使用4个线程加载数据
)

为了引入这些概念,我们现在接触一个新的而且非常经典的数据集:MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练

1、用到的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)

torchvision

├── datasets       # 视觉数据集(如 MNIST、CIFAR)

├── transforms     # 视觉数据预处理(如裁剪、翻转、归一化)

├── models         # 预训练模型(如 ResNet、YOLO)

├── utils          # 视觉工具函数(如目标检测后处理)

└── io             # 图像/视频 IO 操作

2、定义预处理操作transform

这里用 torchvision 的 transforms 模块,提供了一系列常用的图像预处理操作

# 数据预处理,该写法非常类似于管道pipeline
# 先归一化,再标准化
transform = transforms.Compose([ # compose用于将多个数据预处理操作按顺序组合成一个整体,参数是一个列表,每个操作是一个元素transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # 标准化,MNIST数据集的均值和标准差,这个值很出名,所以直接使用# 参数格式是元组 (mean_channel1, mean_channel2, ...),由于MNIST是单通道(灰度图),这里只有一个值
])

3、创建dataset实例

torchvision 的 datasets 模块已经预定义了许多常见的数据集,实例化一个数据类就是创建dataset对象了

# 加载MNIST数据集,如果没有会自动下载,pytorch的思路是,数据在加载阶段就预处理结束
# 训练集
train_dataset = datasets.MNIST(root='./data', # 数据存储路径train=True,download=True, # 如果目录下数据不存在则自动下载transform=transform # 应用预处理
)# 测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

3、创建dataloader实例

# 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)

过程就是定义预处理transform ➡ 实例化一个数据集类(创建dataset实例)➡ 创建数据加载器(创建dataloader实例)➡ 后续操作

作业:了解下cifar数据集,尝试获取其中一张图片

cifar-10的图片就是32*32的彩色图,那就存在RGB三个通道上不同的灰度图,分别标准化和反标准化

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms# 设置中文字体
plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei']
plt.rcParams['axes.unicode_minus'] = False    # 正常显示负号# 定义CIFAR-10的均值和标准差
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)# 定义预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(cifar_mean, cifar_std)
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)# CIFAR-10的类别标签
classes = ('飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车')# 随机选择一个样本
index = np.random.randint(0, len(train_dataset))
image, label = train_dataset[index]# 反标准化操作 (针对3通道图像)
image = image.clone()  # 避免修改原始数据
for i in range(3):  # 对RGB三个通道分别反标准化image[i] = image[i] * cifar_std[i] + cifar_mean[i]# 转换为numpy并调整维度 (PyTorch: [C,H,W] → Matplotlib: [H,W,C])
image = np.transpose(image.numpy(), (1, 2, 0))# 显示图像
plt.figure(figsize=(5, 5))
plt.imshow(image)
plt.title(f'随机抽取的样本 - 标签: {classes[label]}')
plt.axis('off')
plt.show()

最后输出图片很模糊,可能因为数据集本身分辨率就不高,plot参数设置了但是中文还是没显示出来,很奇怪搞不懂

@浙大疏锦行

相关文章:

  • 物流项目第七期(路线规划之Neo4j的应用)
  • ImageMagick 是默认使用 CPU 来处理图像,也具备利用 GPU 加速的潜力
  • 从“学术杠精”到“学术创新”
  • 使用 mysqldump 获取 MySQL 表的完整创建 DDL
  • 如何在WordPress网站中添加相册/画廊
  • PyTorch 2.1新特性:TorchDynamo如何实现30%训练加速(原理+自定义编译器开发)
  • 车载通信网络 --- OSI模型:网络层
  • 国芯思辰| 同步降压转换器CN2020应用于智能电视,替换LMR33620
  • 数据结构期末模拟试卷
  • 2025年上半年第2批信息系统项目管理师论文真题解析与范文
  • pgsql 查看每张表大小
  • Python实战:打造高效通讯录管理系统
  • DD3118替代GL3213S 免晶振USB3.0读卡器控制芯片
  • C3P0连接池的使用方法和源码分析
  • 基于Python技术的面部考勤微信小程序的设计与实现
  • WPF【11_2】WPF实战-重构与美化(Entity Framework)-示例
  • Python深度挖掘:openpyxl与pandas高效数据处理实战
  • [问题解决]:Unable to find image ‘containrrr/watchtower:latest‘ locally
  • Python实现自动物体识别---基于深度学习的AI应用实战
  • Orpheus-TTS:AI文本转语音,免费好用的TTS系统
  • 网站的英文/谷歌排名
  • 在什么网站做贸易好/seo零基础培训
  • 网页设计与网站开发期末/获客软件排名前十名
  • 政务网站建设论文/快速整站排名seo教程
  • 网站后台数据库备份怎么做/推广一单500
  • 做网站都是需要什么/企业qq下载