当前位置: 首页 > news >正文

Python Day38

Task:
1.Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
2.Dataloader类
3.minist手写数据集的了解


1. Dataset 类的 __getitem____len__ 方法

在 PyTorch (或类似深度学习框架) 中,Dataset 是一个抽象基类,用于表示你的数据。它通常用于将原始数据(例如图像文件、文本文件、CSV 数据等)处理成模型可以直接消费的格式。

Dataset 类有两个核心的特殊方法,它们是 Python 的“魔法方法”:

  • __len__(self):

    • 作用: 这个方法必须返回数据集中样本的总数量。
    • 实现: 当你创建一个 Dataset 的子类时,你需要实现它来告诉 PyTorch 这个数据集有多大。
    • 用处: Dataloader 需要知道总长度才能正确地进行批处理、洗牌和分发数据。
    • 示例:
      class MyDataset(Dataset):def __init__(self, data_list):self.data = data_list # 假设data_list是你的数据源def __len__(self):return len(self.data) # 返回数据源的长度def __getitem__(self, idx):# ... 具体实现将在下面说明pass
      
  • __getitem__(self, idx):

    • 作用: 这个方法用于根据给定的索引 idx 返回数据集中的一个样本。
    • 实现: 这是最关键的部分。你需要在其中定义如何加载、预处理(如图像变换、文本编码)并返回一个样本及其对应的标签。
    • 返回类型: 通常,它返回一个元组或字典,其中包含一个数据样本和其对应的标签。例如 (image_tensor, label_tensor)
    • 用处: 当 Dataloader 需要获取一个批次的数据时,它会内部多次调用 __getitem__ 来收集单个样本。
    • 示例:
      import torch
      from torch.utils.data import Datasetclass CustomImageDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transform # 用于图像预处理的转换def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]label = self.labels[idx]# 假设这里是加载图像的逻辑 (实际会用Pillow等库)# 为了演示,我们创建一个虚拟图像tensorimage = torch.randn(3, 224, 224) # 3 channels, 224x224 pixelsif self.transform:image = self.transform(image) # 应用预处理return image, label # 返回图像张量和标签
      

总结 Dataset 的作用和特殊方法:

Dataset 类负责:

  1. 数据抽象: 将原始数据封装成一个可迭代、可索引的对象。
  2. 数据加载: 在 __getitem__ 中处理从文件系统或内存中加载单个数据项的逻辑。
  3. 数据预处理: 在 __getitem__ 中应用必要的预处理步骤(如归一化、裁剪、数据增强)。
  4. 提供索引: __len____getitem__ 使得数据集可以通过索引访问,并知道其总大小。

2. DataLoader

DataLoader 是 PyTorch 中一个非常强大的工具,它建立在 Dataset 之上,负责高效地加载和批处理数据。它的核心功能是:

  • 批处理 (Batching): 将单个样本组合成批次,这是深度学习训练的常用方式,因为它可以提高计算效率,并有助于梯度下降的稳定。
  • 洗牌 (Shuffling): 在每个 epoch 开始时随机打乱数据,以防止模型学习到数据中的顺序模式,并提高模型的泛化能力。
  • 多进程数据加载 (Multiprocessing Data Loading): 可以使用多个工作进程并行加载数据,从而减少数据加载成为训练瓶颈的可能性。
  • 内存固定 (Pin Memory): 可以将张量加载到 CUDA 固定内存中,这可以加快数据传输到 GPU 的速度。

DataLoader 的主要参数:

  • dataset: 必须是 torch.utils.data.Dataset 的实例。这是 DataLoader 从中获取数据的来源。
  • batch_size: 每个批次包含的样本数量。
  • shuffle: 布尔值,如果设置为 True,则在每个 epoch 开始时打乱数据。
  • num_workers: 用于数据加载的子进程数量。设置为 0 意味着数据将在主进程中加载。大于 0 会开启多进程,通常能加快加载速度,但也需要更多内存。
  • drop_last: 布尔值,如果设置为 True,则如果数据集大小不能被 batch_size 整除,则最后一个不完整的批次将被丢弃。
  • collate_fn: 可选参数,一个函数,用于如何将单个样本列表合并成一个批次。默认情况下,它会尝试堆叠张量。如果你有复杂的数据结构(如变长序列),你可能需要自定义这个函数。

DataLoader 的使用:

DataLoader 是一个可迭代对象。你可以直接在 for 循环中使用它来获取批次数据。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np# 假设我们有一个简单的Dataset
class SimpleDataset(Dataset):def __init__(self, num_samples=100):self.data = torch.randn(num_samples, 10) # 100个样本,每个样本10个特征self.labels = torch.randint(0, 2, (num_samples,)) # 100个标签,0或1def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 创建数据集实例
my_dataset = SimpleDataset(num_samples=100)# 创建DataLoader实例
train_loader = DataLoader(dataset=my_dataset,batch_size=16,shuffle=True,num_workers=0) # 简单示例,不使用多进程# 迭代DataLoader获取批次数据
for epoch in range(5): # 假设训练5个epochprint(f"\nEpoch {epoch+1}")for batch_idx, (data, labels) in enumerate(train_loader):print(f"  Batch {batch_idx+1}: data shape = {data.shape}, labels shape = {labels.shape}")# 在这里执行模型的前向传播、计算损失、反向传播等训练步骤if batch_idx >= 2: # 只打印前3个批次,避免输出过多break

DataLoaderDataset 的协作:

  • DataLoader 接收一个 Dataset 对象。
  • DataLoader 需要一个批次数据时,它会:
    1. 如果 shuffle=True,它会首先打乱 Dataset 的索引。
    2. 它会选择 batch_size 个索引。
    3. 对于每个选定的索引,它会调用 Dataset__getitem__(idx) 方法来获取单个样本。
    4. 它将这些单个样本集合起来(默认通过 torch.stacktorch.cat),形成一个批次张量。
    5. 最终将批次张量返回给你的训练循环。

3. MNIST 手写数字数据集的了解

MNIST (Modified National Institute of Standards and Technology) 是一个经典的、广泛使用的计算机视觉数据集,被誉为“深度学习的 Hello World”。

主要特点:

  1. 内容: 包含大量手写数字的灰度图像。
  2. 类别: 10 个类别,对应数字 0 到 9。
  3. 图像大小: 每张图像都是 28x28 像素。
  4. 数据量:
    • 训练集: 60,000 张图像,用于训练模型。
    • 测试集: 10,000 张图像,用于评估模型的性能。
  5. 图像格式: 灰度图像,每个像素的值通常在 0 到 255 之间,表示像素亮度。

MNIST 的重要性:

  • 入门级: 简单且足够小,适合初学者学习深度学习的基本概念和 PyTorch 的使用。
  • 基准: 由于其标准化和广泛使用,它经常作为新算法和模型架构的初步测试基准。
  • 低计算需求: 训练一个在 MNIST 上表现良好的模型通常不需要强大的 GPU,普通 CPU 也能完成。

PyTorch 中使用 MNIST:

PyTorch 的 torchvision 库提供了方便的工具来下载和加载 MNIST 数据集。

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 1. 定义数据转换 (Transformations)
# MNIST图像是PIL.Image类型,需要转换为Tensor,并进行归一化。
# 归一化是常用的预处理步骤,将像素值缩放到一个特定范围(例如0到1,或-1到1)。
# 对于MNIST,通常是 (mean=0.1307, std=0.3081),这是根据整个MNIST数据集计算得出的。
transform = transforms.Compose([transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为FloatTensor,并除以255将像素值缩放到0-1transforms.Normalize((0.1307,), (0.3081,)) # 归一化,(mean,) (std,),对于灰度图像是单通道
])# 2. 下载并加载训练数据集
# root: 数据存放的根目录
# train=True: 获取训练集
# download=True: 如果数据不存在,则下载
# transform: 应用上述定义的转换
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 3. 下载并加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 4. 创建 DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) # num_workers可以根据你的CPU核心数调整
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) # 测试集通常不打乱# 5. 遍历训练数据 (示例)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")for batch_idx, (data, target) in enumerate(train_loader):print(f"训练批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")# data.shape 会是 [batch_size, 1, 28, 28] (1是通道数,28x28是图像尺寸)# target.shape 会是 [batch_size]break # 只打印第一个批次# 6. 遍历测试数据 (示例)
for batch_idx, (data, target) in enumerate(test_loader):print(f"测试批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")break

相关文章:

  • QT-Creator安装教程(windows)
  • 2.2.2 06年T1
  • Python训练营---Day40
  • 【笔记】Windows 系统安装 Scoop 包管理工具
  • 在线制作幼教早教行业自适应网站教程
  • E. Melody 【CF1026 (Div. 2)】 (求欧拉路径之Hierholzer算法)
  • PHP7+MySQL5.6 查立得源码授权系统DNS验证版
  • GEARS以及与基础模型结合
  • 英语复习笔记 2
  • 彻底理解 JavaScript 浅拷贝与深拷贝:原理、实现与应用
  • USB MSC
  • 04-redis-分布式锁-edisson
  • 后端项目中静态文案国际化语言包构建选型
  • 【计算机网络】fork()+exec()创建新进程(僵尸进程及孤儿进程)
  • 城市内涝精准监测・智能预警・高效应对:治理方案解析
  • 拉深工艺模块——回转体拉深件毛坯尺寸的确定(一)
  • 为什么建立 TCP 连接时,初始序列号不固定?
  • Linux多线程(六)之线程控制4【线程ID及进程地址空间布局】
  • 使用 SpyGlass Power Verify 解决方案中的规则
  • 正点原子AU15开发板!板载40G QSFP、PCIe3.0x8和FMC LPC等接口,性能强悍!
  • 做招聘网站都需要什么手续/免费制作网站
  • 潍坊专业网站建设多少钱/网站一般怎么推广
  • 网站定制/全网营销推广怎么做
  • 网站建设哪个好/百度搜索怎么优化
  • 网站技术维护/网络推广可做哪些方面
  • 申请域名之后如何做网站/地推的60种方法