【PyTorch】PyTorch中的数据预处理操作
PyTorch深度学习总结
第十二章 PyTorch中的数据预处理操作
文章目录
- PyTorch深度学习总结
- 前言
- 一、`torch.utils.data` 模块
- 1. 核心组件
- 2.常用工具类
- 3. 数据采样器
- 二、常用函数
前言
上文介绍了PyTorch中torch.nn模块的全连接层,本文将介绍PyTorch中torch.utils.data
模块的数据处理操作:
一、torch.utils.data
模块
torch.utils.data
是 PyTorch
中用于数据处理和加载的重要模块,它提供了一系列工具和类,方便用户对数据集进行管理和操作。以下是对该模块的详细介绍:
1. 核心组件
1.1 Dataset 类
- 作用:
Dataset
类是一个抽象基类,用于表示数据集。用户需要继承这个类并实现__len__
和__getitem__
方法,以定义数据集的长度和如何获取数据集中的单个样本。- 示例:
import torch from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]data = [1, 2, 3, 4, 5] dataset = MyDataset(data) print(len(dataset)) print(dataset[2])
- 解释:在这个示例中,我们创建了一个自定义的数据集类
MyDataset
,它接受一个列表作为数据。__len__
方法返回数据集的长度,__getitem__
方法根据索引返回数据集中的单个样本。
1.2 DataLoader 类
- 作用:
DataLoader
类用于将数据集封装成一个可迭代的对象,支持批量加载数据、打乱数据顺序、多线程加载等功能,方便在训练模型时使用。- 示例:
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=2, shuffle=True) for batch in dataloader:print(batch)
- 解释:在这个示例中,我们将之前创建的
dataset
封装成一个DataLoader
对象,设置批量大小为 2,并开启数据打乱功能。然后通过迭代DataLoader
对象,可以逐批获取数据。
2.常用工具类
2.1 Subset 类
- 作用:
Subset
类用于创建数据集的子集,通过指定数据集和索引列表来获取子集中的数据。- 示例:
from torch.utils.data import Subsetsubset = Subset(dataset, [0, 2, 4]) print(len(subset)) print(subset[1])
- 解释:在这个示例中,我们创建了
dataset
的一个子集subset
,只包含索引为 0、2、4 的样本。
2.2 ConcatDataset 类
- 作用:
ConcatDataset
类用于将多个数据集合并成一个数据集。- 示例:
from torch.utils.data import ConcatDatasetdataset1 = MyDataset([1, 2, 3]) dataset2 = MyDataset([4, 5, 6]) concat_dataset = ConcatDataset([dataset1, dataset2]) print(len(concat_dataset)) print(concat_dataset[4])
- 解释:在这个示例中,我们将两个自定义的数据集
dataset1
和dataset2
合并成一个新的数据集concat_dataset
。
3. 数据采样器
3.1 RandomSampler 类
- 作用:
RandomSampler
类用于随机采样数据集中的样本,常用于打乱数据顺序。- 示例:
from torch.utils.data import RandomSamplersampler = RandomSampler(dataset) dataloader = DataLoader(dataset, batch_size=2, sampler=sampler) for batch in dataloader:print(batch)
- 解释:在这个示例中,我们使用
RandomSampler
类对数据集进行随机采样,然后将采样器传递给DataLoader
对象,这样在加载数据时会随机获取样本。
3.2 SequentialSampler 类
- 作用:
SequentialSampler
类用于按顺序采样数据集中的样本。- 示例:
from torch.utils.data import SequentialSamplersampler = SequentialSampler(dataset) dataloader = DataLoader(dataset, batch_size=2, sampler=sampler) for batch in dataloader:print(batch)
- 解释:在这个示例中,我们使用
SequentialSampler
类对数据集进行顺序采样,然后将采样器传递给DataLoader
对象,这样在加载数据时会按顺序获取样本。
二、常用函数
操作函数 | 功能 |
---|---|
torch.utils.data.TensorDataset() | 将数据处理为张量 |
torch.utils.data.ConcatDataset() | 连接多个数据集 |
torch.utils.data.Subset() | 根据索引获取数据集的子集 |
torch.utils.data.DataLoader() | 数据加载器 |
torch.utils.data.random_split() | 随机将数据集拆分为给定长度的非重叠新数据集 |