当前位置: 首页 > news >正文

人工智能-python-深度学习-数据准备

文章目录

  • 数据准备
    • 数据加载器
      • 1. 构建数据类
        • 1.1 Dataset类
        • 1.2 TensorDataset类
        • 为什么要使用这些类?
      • 2. 数据加载器(DataLoader)
        • 为什么使用数据加载器?
    • 数据集加载案例
      • 1. 加载 CSV 数据集
        • 为什么这样做?
      • 2. 加载图片数据集
        • 为什么要这么做?
      • 3. 加载官方数据集
        • 为什么使用官方数据集?
    • 数据探索与清洗
      • 1. 检查图像与标注匹配
        • 为什么这样做?
      • 2. 删除损坏图像
        • 为什么这么做?
      • 3. 数据集划分
        • 为什么划分数据集?
    • 结论


数据准备

数据加载器

1. 构建数据类

在深度学习中,数据加载器(DataLoader)负责高效地加载和批量化数据,以便可以快速输入到模型中进行训练。PyTorch 提供了 DatasetDataLoader 类来简化这一过程。

1.1 Dataset类

Dataset 类是一个抽象类,它允许你自定义如何从数据源中加载数据。通常,你需要继承 Dataset 类并实现以下两个方法:

  • __len__():返回数据集的大小。
  • __getitem__():返回给定索引的数据样本。
  • 在Pytorch中,构建自定义数据加载类通常需要继承torch.utils.data.Dataset并实现以下几个方法:
  1. _init_ 方法
    用于初始化数据集对象:通常在这里加载数据,或者定义如何从存储中获取数据的路径和方法。

    def __init__(self, data, labels):self.data = dataself.labels = labels
    
  2. _len_ 方法
    返回样本数量:需要实现,以便 Dataloader加载器能够知道数据集的大小。

    def __len__(self):return len(self.data)
    
  3. _getitem_ 方法
    根据索引返回样本:将从数据集中提取一个样本,并可能对样本进行预处理或变换。

    def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label
    

​ 如果你需要进行更多的预处理或数据变换,可以在 _getitem_ 方法中添加额外的逻辑。

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 定义数据加载类
class CustomDataset(Dataset):def __init__(self, data, labels):"""初始化数据集:data: 样本数据(例如,一个 NumPy 数组或 PyTorch 张量):labels: 样本标签"""self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):index = min(max(index, 0), len(self.data) - 1)sample = self.data[index]label = self.labels[index]return sample, labeldef test001():# 简单的数据集准备data_x = torch.randn(666, 20, requires_grad=True, dtype=torch.float32)data_y = torch.randn(data_x.shape[0], 1, dtype=torch.float32)dataset = CustomDataset(data_x, data_y)# 随便打印个数据看一下print(dataset[0])if __name__ == "__main__":test001()
1.2 TensorDataset类

TensorDataset 类是 Dataset 的一个简单实现,它将张量打包为元组。对于许多常见的任务(如监督学习),我们可以直接使用 TensorDataset

from torch.utils.data import TensorDatasetdata = torch.randn(100, 3)  # 100个样本,3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签dataset = TensorDataset(data, labels)
为什么要使用这些类?

通过继承或使用 Dataset 类,你可以灵活地处理数据加载的细节,如数据预处理、标签处理、数据增强等。这种方式使得代码更加模块化和可复用。

2. 数据加载器(DataLoader)

DataLoader 是用来批量加载数据的工具,它会自动地将数据分成小批次并进行洗牌。你可以根据需要指定批量大小、是否打乱数据等。

from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=32, shuffle=True)
为什么使用数据加载器?
  • 批处理:通过批量化加载数据,减少内存消耗。
  • 多线程加载:通过指定 num_workers 参数,数据可以并行加载,减少I/O瓶颈,提高训练效率。
  • 数据洗牌:在每个 epoch 之前洗牌数据,避免模型过拟合。

数据集加载案例

1. 加载 CSV 数据集

CSV 格式的文件通常包含表格数据,每一行表示一个样本,列表示不同的特征。在 PyTorch 中,我们可以使用 pandas 库加载 CSV 数据并将其转换为 PyTorch 的张量。

import pandas as pd
import torch
from torch.utils.data import Datasetclass CSVDataset(Dataset):def __init__(self, file_path):self.data = pd.read_csv(file_path)self.features = torch.tensor(self.data.iloc[:, :-1].values, dtype=torch.float32)self.labels = torch.tensor(self.data.iloc[:, -1].values, dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.features[idx], self.labels[idx]
为什么这样做?

CSV 文件格式简单且广泛使用,通常用于存储结构化数据。通过将其加载到 PyTorch 的张量中,能够直接利用 GPU 进行加速处理。

2. 加载图片数据集

对于图像数据集,通常需要加载和预处理图像。这可以通过 torchvision 库来实现,它提供了 ImageFolder 等工具,可以方便地从文件夹中加载图像数据。

from torchvision import datasets, transformstransform = transforms.Compose([transforms.Resize((128, 128)),transforms.ToTensor(),
])dataset = datasets.ImageFolder('path/to/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
为什么要这么做?

图像数据通常需要进行尺寸调整、归一化等预处理,torchvision.transforms 提供了多种常用的图像变换操作,可以方便地应用到数据加载过程中。

3. 加载官方数据集

PyTorch 提供了许多常见的标准数据集,可以通过 torchvision.datasets 轻松加载,如 MNIST、CIFAR-10 等。加载这些数据集非常简单,可以直接使用预定义的 API。

from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
为什么使用官方数据集?

官方数据集已被预处理,并且通常用于标准化的训练和评估任务,使用它们可以避免处理数据集的繁琐步骤,直接进行模型训练和测试。

数据探索与清洗

1. 检查图像与标注匹配

图像和标注匹配是确保模型训练正确性的第一步。在某些任务中,图像和其对应的标注可能会不匹配。通常需要检查每个图像是否有有效的标注。

import cv2
import osimage_path = 'path/to/image'
annotation_path = 'path/to/annotation'# 检查图像文件是否有效
image = cv2.imread(image_path)
if image is None:print(f"Invalid image file: {image_path}")# 检查标注文件是否有效
with open(annotation_path) as f:annotations = f.readlines()if not annotations:print(f"Invalid annotation file: {annotation_path}")
为什么这样做?

确保图像和标注文件匹配是保证模型训练质量的关键。如果图像和标注不匹配,会导致模型学习到错误的特征,影响训练效果。

2. 删除损坏图像

图像数据集可能包含损坏的图像(如无法打开的文件)。在训练模型之前,需要清理这些损坏的图像。

def check_image_validity(image_path):try:img = cv2.imread(image_path)if img is None:return Falsereturn Trueexcept Exception as e:return False# 遍历数据集,删除损坏图像
image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', 'path/to/image3.jpg']
valid_images = [img for img in image_paths if check_image_validity(img)]
为什么这么做?

删除损坏图像有助于确保训练数据集的质量,避免模型因损坏的图像而学到不必要的噪音。

3. 数据集划分

数据集通常需要划分为训练集、验证集和测试集。常见的做法是将 70%-80% 的数据用作训练,剩下的用于验证和测试。可以使用 train_test_split 来完成这一过程。

from sklearn.model_selection import train_test_splitdata = [i for i in range(100)]  # 假设有100个样本
train_data, test_data = train_test_split(data, test_size=0.2)
为什么划分数据集?

数据集划分有助于模型评估。训练集用于训练模型,验证集用于调参,测试集用于评估模型的最终性能。


结论

数据准备是深度学习项目中至关重要的一步。通过合理使用数据加载器和进行必要的数据清洗,可以确保模型能够高效且准确地进行训练。特别是图像和文本等复杂数据类型,适当的预处理和清洗能够极大地提高模型的性能和泛化能力。在进行数据处理时,确保数据完整性和一致性是非常关键的,避免无效数据干扰训练过程。

http://www.dtcms.com/a/350794.html

相关文章:

  • 路径总和。
  • 同一性和斗争性
  • 使用 gemini api + 异步执行,批量翻译文档
  • 【Task04】:向量及多模态嵌入(第三章1、2节)
  • 解锁表格数据处理的高效方法-通用表格识别接口
  • sudo 升级
  • Spring Boot 项目打包成可执行程序
  • 3秒传输大文件:cpolar+Localsend实现跨网络秒传
  • 内核编译 day61
  • Ubuntu安装及配置Git(Ubuntu install and config Git Tools)
  • Linux 磁盘文件系统
  • 【银河麒麟桌面系统】PXE实现arm、x86等多架构安装
  • Linux-进程相关函数
  • Vulkan学到什么程度才算学会
  • 关系轮-和弦图的可视化
  • VPS一键测试脚本NodeQuality,无痕体验+自动导出,服务器测试更轻松
  • illustrator-01
  • 我的项目管理之路-组织级项目管理(二)
  • ASW3642 pin√pin替代TS3DV642方案,可使用原小板只需简单调整外围|ASW3642 HDMI二切一双向切换器方案
  • QT6软件设置图标方法
  • Chrome插件开发:在网页上运行脚本
  • 6种简单方法将大视频从iPhone传输到PC
  • 音频相关数学支持
  • C++ 类型转换深度解析
  • 【应急响应工具教程】Unix/Linux 轻量级工具集Busybox
  • 为什么软解码依然重要?深入理解视频播放与开发应用(视频解码)
  • STM32F103C8T6引脚分布
  • 1. 并发产生背景 并发解决原理
  • 【JavaEE】文件IO操作
  • MyBatis 从入门到精通:一篇就够的实战指南(Java)