人工智能-python-深度学习-数据准备
文章目录
- 数据准备
- 数据加载器
- 1. 构建数据类
- 1.1 Dataset类
- 1.2 TensorDataset类
- 为什么要使用这些类?
- 2. 数据加载器(DataLoader)
- 为什么使用数据加载器?
- 数据集加载案例
- 1. 加载 CSV 数据集
- 为什么这样做?
- 2. 加载图片数据集
- 为什么要这么做?
- 3. 加载官方数据集
- 为什么使用官方数据集?
- 数据探索与清洗
- 1. 检查图像与标注匹配
- 为什么这样做?
- 2. 删除损坏图像
- 为什么这么做?
- 3. 数据集划分
- 为什么划分数据集?
- 结论
数据准备
数据加载器
1. 构建数据类
在深度学习中,数据加载器(DataLoader)负责高效地加载和批量化数据,以便可以快速输入到模型中进行训练。PyTorch 提供了 Dataset
和 DataLoader
类来简化这一过程。
1.1 Dataset类
Dataset
类是一个抽象类,它允许你自定义如何从数据源中加载数据。通常,你需要继承 Dataset
类并实现以下两个方法:
__len__()
:返回数据集的大小。__getitem__()
:返回给定索引的数据样本。- 在Pytorch中,构建自定义数据加载类通常需要继承torch.utils.data.Dataset并实现以下几个方法:
-
_init_ 方法
用于初始化数据集对象:通常在这里加载数据,或者定义如何从存储中获取数据的路径和方法。def __init__(self, data, labels):self.data = dataself.labels = labels
-
_len_ 方法
返回样本数量:需要实现,以便 Dataloader加载器能够知道数据集的大小。def __len__(self):return len(self.data)
-
_getitem_ 方法
根据索引返回样本:将从数据集中提取一个样本,并可能对样本进行预处理或变换。def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label
如果你需要进行更多的预处理或数据变换,可以在 _getitem_ 方法中添加额外的逻辑。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 定义数据加载类
class CustomDataset(Dataset):def __init__(self, data, labels):"""初始化数据集:data: 样本数据(例如,一个 NumPy 数组或 PyTorch 张量):labels: 样本标签"""self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):index = min(max(index, 0), len(self.data) - 1)sample = self.data[index]label = self.labels[index]return sample, labeldef test001():# 简单的数据集准备data_x = torch.randn(666, 20, requires_grad=True, dtype=torch.float32)data_y = torch.randn(data_x.shape[0], 1, dtype=torch.float32)dataset = CustomDataset(data_x, data_y)# 随便打印个数据看一下print(dataset[0])if __name__ == "__main__":test001()
1.2 TensorDataset类
TensorDataset
类是 Dataset
的一个简单实现,它将张量打包为元组。对于许多常见的任务(如监督学习),我们可以直接使用 TensorDataset
。
from torch.utils.data import TensorDatasetdata = torch.randn(100, 3) # 100个样本,3个特征
labels = torch.randint(0, 2, (100,)) # 100个标签dataset = TensorDataset(data, labels)
为什么要使用这些类?
通过继承或使用 Dataset
类,你可以灵活地处理数据加载的细节,如数据预处理、标签处理、数据增强等。这种方式使得代码更加模块化和可复用。
2. 数据加载器(DataLoader)
DataLoader
是用来批量加载数据的工具,它会自动地将数据分成小批次并进行洗牌。你可以根据需要指定批量大小、是否打乱数据等。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=32, shuffle=True)
为什么使用数据加载器?
- 批处理:通过批量化加载数据,减少内存消耗。
- 多线程加载:通过指定
num_workers
参数,数据可以并行加载,减少I/O瓶颈,提高训练效率。 - 数据洗牌:在每个 epoch 之前洗牌数据,避免模型过拟合。
数据集加载案例
1. 加载 CSV 数据集
CSV 格式的文件通常包含表格数据,每一行表示一个样本,列表示不同的特征。在 PyTorch 中,我们可以使用 pandas
库加载 CSV 数据并将其转换为 PyTorch 的张量。
import pandas as pd
import torch
from torch.utils.data import Datasetclass CSVDataset(Dataset):def __init__(self, file_path):self.data = pd.read_csv(file_path)self.features = torch.tensor(self.data.iloc[:, :-1].values, dtype=torch.float32)self.labels = torch.tensor(self.data.iloc[:, -1].values, dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.features[idx], self.labels[idx]
为什么这样做?
CSV 文件格式简单且广泛使用,通常用于存储结构化数据。通过将其加载到 PyTorch 的张量中,能够直接利用 GPU 进行加速处理。
2. 加载图片数据集
对于图像数据集,通常需要加载和预处理图像。这可以通过 torchvision
库来实现,它提供了 ImageFolder
等工具,可以方便地从文件夹中加载图像数据。
from torchvision import datasets, transformstransform = transforms.Compose([transforms.Resize((128, 128)),transforms.ToTensor(),
])dataset = datasets.ImageFolder('path/to/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
为什么要这么做?
图像数据通常需要进行尺寸调整、归一化等预处理,torchvision.transforms
提供了多种常用的图像变换操作,可以方便地应用到数据加载过程中。
3. 加载官方数据集
PyTorch 提供了许多常见的标准数据集,可以通过 torchvision.datasets
轻松加载,如 MNIST、CIFAR-10 等。加载这些数据集非常简单,可以直接使用预定义的 API。
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
为什么使用官方数据集?
官方数据集已被预处理,并且通常用于标准化的训练和评估任务,使用它们可以避免处理数据集的繁琐步骤,直接进行模型训练和测试。
数据探索与清洗
1. 检查图像与标注匹配
图像和标注匹配是确保模型训练正确性的第一步。在某些任务中,图像和其对应的标注可能会不匹配。通常需要检查每个图像是否有有效的标注。
import cv2
import osimage_path = 'path/to/image'
annotation_path = 'path/to/annotation'# 检查图像文件是否有效
image = cv2.imread(image_path)
if image is None:print(f"Invalid image file: {image_path}")# 检查标注文件是否有效
with open(annotation_path) as f:annotations = f.readlines()if not annotations:print(f"Invalid annotation file: {annotation_path}")
为什么这样做?
确保图像和标注文件匹配是保证模型训练质量的关键。如果图像和标注不匹配,会导致模型学习到错误的特征,影响训练效果。
2. 删除损坏图像
图像数据集可能包含损坏的图像(如无法打开的文件)。在训练模型之前,需要清理这些损坏的图像。
def check_image_validity(image_path):try:img = cv2.imread(image_path)if img is None:return Falsereturn Trueexcept Exception as e:return False# 遍历数据集,删除损坏图像
image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', 'path/to/image3.jpg']
valid_images = [img for img in image_paths if check_image_validity(img)]
为什么这么做?
删除损坏图像有助于确保训练数据集的质量,避免模型因损坏的图像而学到不必要的噪音。
3. 数据集划分
数据集通常需要划分为训练集、验证集和测试集。常见的做法是将 70%-80% 的数据用作训练,剩下的用于验证和测试。可以使用 train_test_split
来完成这一过程。
from sklearn.model_selection import train_test_splitdata = [i for i in range(100)] # 假设有100个样本
train_data, test_data = train_test_split(data, test_size=0.2)
为什么划分数据集?
数据集划分有助于模型评估。训练集用于训练模型,验证集用于调参,测试集用于评估模型的最终性能。
结论
数据准备是深度学习项目中至关重要的一步。通过合理使用数据加载器和进行必要的数据清洗,可以确保模型能够高效且准确地进行训练。特别是图像和文本等复杂数据类型,适当的预处理和清洗能够极大地提高模型的性能和泛化能力。在进行数据处理时,确保数据完整性和一致性是非常关键的,避免无效数据干扰训练过程。