Python Day38
Task:
1.Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
2.Dataloader类
3.minist手写数据集的了解
1. Dataset
类的 __getitem__
和 __len__
方法
在 PyTorch (或类似深度学习框架) 中,Dataset
是一个抽象基类,用于表示你的数据。它通常用于将原始数据(例如图像文件、文本文件、CSV 数据等)处理成模型可以直接消费的格式。
Dataset
类有两个核心的特殊方法,它们是 Python 的“魔法方法”:
-
__len__(self)
:- 作用: 这个方法必须返回数据集中样本的总数量。
- 实现: 当你创建一个
Dataset
的子类时,你需要实现它来告诉 PyTorch 这个数据集有多大。 - 用处: Dataloader 需要知道总长度才能正确地进行批处理、洗牌和分发数据。
- 示例:
class MyDataset(Dataset):def __init__(self, data_list):self.data = data_list # 假设data_list是你的数据源def __len__(self):return len(self.data) # 返回数据源的长度def __getitem__(self, idx):# ... 具体实现将在下面说明pass
-
__getitem__(self, idx)
:- 作用: 这个方法用于根据给定的索引
idx
返回数据集中的一个样本。 - 实现: 这是最关键的部分。你需要在其中定义如何加载、预处理(如图像变换、文本编码)并返回一个样本及其对应的标签。
- 返回类型: 通常,它返回一个元组或字典,其中包含一个数据样本和其对应的标签。例如
(image_tensor, label_tensor)
。 - 用处: 当 Dataloader 需要获取一个批次的数据时,它会内部多次调用
__getitem__
来收集单个样本。 - 示例:
import torch from torch.utils.data import Datasetclass CustomImageDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transform # 用于图像预处理的转换def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]label = self.labels[idx]# 假设这里是加载图像的逻辑 (实际会用Pillow等库)# 为了演示,我们创建一个虚拟图像tensorimage = torch.randn(3, 224, 224) # 3 channels, 224x224 pixelsif self.transform:image = self.transform(image) # 应用预处理return image, label # 返回图像张量和标签
- 作用: 这个方法用于根据给定的索引
总结 Dataset
的作用和特殊方法:
Dataset
类负责:
- 数据抽象: 将原始数据封装成一个可迭代、可索引的对象。
- 数据加载: 在
__getitem__
中处理从文件系统或内存中加载单个数据项的逻辑。 - 数据预处理: 在
__getitem__
中应用必要的预处理步骤(如归一化、裁剪、数据增强)。 - 提供索引:
__len__
和__getitem__
使得数据集可以通过索引访问,并知道其总大小。
2. DataLoader
类
DataLoader
是 PyTorch 中一个非常强大的工具,它建立在 Dataset
之上,负责高效地加载和批处理数据。它的核心功能是:
- 批处理 (Batching): 将单个样本组合成批次,这是深度学习训练的常用方式,因为它可以提高计算效率,并有助于梯度下降的稳定。
- 洗牌 (Shuffling): 在每个 epoch 开始时随机打乱数据,以防止模型学习到数据中的顺序模式,并提高模型的泛化能力。
- 多进程数据加载 (Multiprocessing Data Loading): 可以使用多个工作进程并行加载数据,从而减少数据加载成为训练瓶颈的可能性。
- 内存固定 (Pin Memory): 可以将张量加载到 CUDA 固定内存中,这可以加快数据传输到 GPU 的速度。
DataLoader
的主要参数:
dataset
: 必须是torch.utils.data.Dataset
的实例。这是DataLoader
从中获取数据的来源。batch_size
: 每个批次包含的样本数量。shuffle
: 布尔值,如果设置为True
,则在每个 epoch 开始时打乱数据。num_workers
: 用于数据加载的子进程数量。设置为 0 意味着数据将在主进程中加载。大于 0 会开启多进程,通常能加快加载速度,但也需要更多内存。drop_last
: 布尔值,如果设置为True
,则如果数据集大小不能被batch_size
整除,则最后一个不完整的批次将被丢弃。collate_fn
: 可选参数,一个函数,用于如何将单个样本列表合并成一个批次。默认情况下,它会尝试堆叠张量。如果你有复杂的数据结构(如变长序列),你可能需要自定义这个函数。
DataLoader
的使用:
DataLoader
是一个可迭代对象。你可以直接在 for
循环中使用它来获取批次数据。
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np# 假设我们有一个简单的Dataset
class SimpleDataset(Dataset):def __init__(self, num_samples=100):self.data = torch.randn(num_samples, 10) # 100个样本,每个样本10个特征self.labels = torch.randint(0, 2, (num_samples,)) # 100个标签,0或1def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 创建数据集实例
my_dataset = SimpleDataset(num_samples=100)# 创建DataLoader实例
train_loader = DataLoader(dataset=my_dataset,batch_size=16,shuffle=True,num_workers=0) # 简单示例,不使用多进程# 迭代DataLoader获取批次数据
for epoch in range(5): # 假设训练5个epochprint(f"\nEpoch {epoch+1}")for batch_idx, (data, labels) in enumerate(train_loader):print(f" Batch {batch_idx+1}: data shape = {data.shape}, labels shape = {labels.shape}")# 在这里执行模型的前向传播、计算损失、反向传播等训练步骤if batch_idx >= 2: # 只打印前3个批次,避免输出过多break
DataLoader
和 Dataset
的协作:
DataLoader
接收一个Dataset
对象。- 当
DataLoader
需要一个批次数据时,它会:- 如果
shuffle=True
,它会首先打乱Dataset
的索引。 - 它会选择
batch_size
个索引。 - 对于每个选定的索引,它会调用
Dataset
的__getitem__(idx)
方法来获取单个样本。 - 它将这些单个样本集合起来(默认通过
torch.stack
或torch.cat
),形成一个批次张量。 - 最终将批次张量返回给你的训练循环。
- 如果
3. MNIST 手写数字数据集的了解
MNIST (Modified National Institute of Standards and Technology) 是一个经典的、广泛使用的计算机视觉数据集,被誉为“深度学习的 Hello World”。
主要特点:
- 内容: 包含大量手写数字的灰度图像。
- 类别: 10 个类别,对应数字 0 到 9。
- 图像大小: 每张图像都是 28x28 像素。
- 数据量:
- 训练集: 60,000 张图像,用于训练模型。
- 测试集: 10,000 张图像,用于评估模型的性能。
- 图像格式: 灰度图像,每个像素的值通常在 0 到 255 之间,表示像素亮度。
MNIST 的重要性:
- 入门级: 简单且足够小,适合初学者学习深度学习的基本概念和 PyTorch 的使用。
- 基准: 由于其标准化和广泛使用,它经常作为新算法和模型架构的初步测试基准。
- 低计算需求: 训练一个在 MNIST 上表现良好的模型通常不需要强大的 GPU,普通 CPU 也能完成。
PyTorch 中使用 MNIST:
PyTorch 的 torchvision
库提供了方便的工具来下载和加载 MNIST 数据集。
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 1. 定义数据转换 (Transformations)
# MNIST图像是PIL.Image类型,需要转换为Tensor,并进行归一化。
# 归一化是常用的预处理步骤,将像素值缩放到一个特定范围(例如0到1,或-1到1)。
# 对于MNIST,通常是 (mean=0.1307, std=0.3081),这是根据整个MNIST数据集计算得出的。
transform = transforms.Compose([transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为FloatTensor,并除以255将像素值缩放到0-1transforms.Normalize((0.1307,), (0.3081,)) # 归一化,(mean,) (std,),对于灰度图像是单通道
])# 2. 下载并加载训练数据集
# root: 数据存放的根目录
# train=True: 获取训练集
# download=True: 如果数据不存在,则下载
# transform: 应用上述定义的转换
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 3. 下载并加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 4. 创建 DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) # num_workers可以根据你的CPU核心数调整
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) # 测试集通常不打乱# 5. 遍历训练数据 (示例)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")for batch_idx, (data, target) in enumerate(train_loader):print(f"训练批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")# data.shape 会是 [batch_size, 1, 28, 28] (1是通道数,28x28是图像尺寸)# target.shape 会是 [batch_size]break # 只打印第一个批次# 6. 遍历测试数据 (示例)
for batch_idx, (data, target) in enumerate(test_loader):print(f"测试批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")break