当前位置: 首页 > news >正文

CIFAR10 数据集自定义处理方法

CIFAR10 数据集自定义处理方法

可以自定义训练集和测试集中不同类别的样本的数量。可用于模拟类别不平衡问题,存在混淆数据问题。

import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random

# 自定义数据集类,继承自 torch.utils.data.Dataset
class CustomCIFAR10Dataset(Dataset):
    def __init__(self, images, labels, transform=None):
        """
        自定义数据集类
        :param images: 图像数据,numpy 数组格式
        :param labels: 标签数据,numpy 数组格式
        :param transform: 可选的图像预处理转换
        """
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform:
            image = self.transform(image)
        label = self.labels[index]
        return image, label

def create_custom_dataset(positive_classes, negative_classes, sample_counts=None, transform=None, train=True):
    """
    创建自定义数据集(训练集或测试集)
    :param positive_classes: 正类别的类别列表
    :param negative_classes: 负类别的类别列表
    :param sample_counts: 每个类别的样本数量限制,字典形式 {类: 样本数量}
    :param transform: 图像预处理转换
    :param train: 是否是训练集(True)还是测试集(False)
    :return: 创建的自定义数据集(CustomCIFAR10Dataset)和原始数据集
    """
    # 下载 CIFAR-10 数据集(训练集或测试集)
    dataset = dsets.CIFAR10(root='./data', train=train, download=True, transform=transforms.ToTensor())
    images = dataset.data  # numpy array, shape [N, 32, 32, 3]
    targets = np.array(dataset.targets)  # shape [N]
    
    new_images = []
    new_labels = []
    selected_global_indices = []

    for cls in np.concatenate((positive_classes, negative_classes)):
        # 获取当前类别的样本索引
        indices = np.where(targets == cls)[0]
        
        # 如果有样本数量限制,则抽取样本
        if sample_counts is not None and cls in sample_counts:
            num_samples = min(sample_counts[cls], len(indices))
            selected_indices = np.random.choice(indices, num_samples, replace=False)
        else:
            selected_indices = indices
        
        selected_global_indices.extend(selected_indices.tolist())
        
        # 为正类别标签为1,负类别标签为0
        for idx in selected_indices:
            new_images.append(images[idx])
            if cls in positive_classes:
                new_labels.append(1)
            else:
                new_labels.append(0)

    # 转换为 numpy 数组
    new_images = np.array(new_images)
    new_labels = np.array(new_labels)
    
    # 打乱新数据集
    perm = np.random.permutation(len(new_labels))
    new_images = new_images[perm]
    new_labels = new_labels[perm]
    
    # 创建自定义数据集
    custom_dataset = CustomCIFAR10Dataset(new_images, new_labels, transform=transform)
    return custom_dataset, dataset

if __name__ == '__main__':
    # 定义正类别和负类别
    positive_classes = [0, 1, 2, 3, 4]
    negative_classes = [5, 6, 7, 8, 9]
    
    # 定义每个类别需要抽取的样本数量
    sample_counts = {0: 500, 1: 500, 2: 500, 3: 500, 4: 500, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}
    
    # 图像预处理
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    # 训练数据集
    train_dataset, base_train_dataset = create_custom_dataset(positive_classes, negative_classes, sample_counts, transform, train=True)
    print('Training dataset size:', len(train_dataset))

    # 测试数据集
    positive_classes_test = [0]
    negative_classes_test = [5, 6, 7, 8, 9]
    sample_counts_test = {0: 1000, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}
    
    test_dataset, base_test_dataset = create_custom_dataset(positive_classes_test, negative_classes_test, sample_counts_test, transform, train=False)
    print('Test dataset size:', len(test_dataset))

    # 使用 DataLoader 加载数据集
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
    # 打印加载器中的数据量
    for images, labels in train_loader:
        print(f"Batch size: {len(images)}, Labels: {labels}")
        break

代码详细解释文档

1. 自定义数据集类 CustomCIFAR10Dataset

此类继承自 torch.utils.data.Dataset,用于自定义数据集的管理,具体功能如下:

  • __init__: 初始化方法,接受图像数据、标签数据和可能的图像预处理变换。
  • __len__: 返回数据集的长度,即样本数量。
  • __getitem__: 根据索引返回样本图像和标签,若定义了预处理变换,则应用该变换。
2. create_custom_dataset 函数

此函数用于创建训练集或测试集,并按类别划分和抽样。

  • positive_classes: 正类别的类别列表,标签为 1。
  • negative_classes: 负类别的类别列表,标签为 0。
  • sample_counts: 可选,字典形式,指定每个类别的样本数量限制。如果没有该参数,则使用所有样本。
  • transform: 可选,图像预处理变换。
  • train: 是否为训练集。如果为 True,则加载训练集;如果为 False,则加载测试集。
3. 数据集的处理流程
  • 从 CIFAR-10 下载训练集或测试集,获取图像数据和标签。
  • 根据给定的类别信息,抽取所需类别的图像样本,并为正类分配标签为 1,负类分配标签为 0。
  • 如果有样本数量限制,则从每个类别中随机选择样本。
  • 将抽取的图像和标签打乱顺序,并创建自定义数据集 CustomCIFAR10Dataset
4. 训练集和测试集的使用

在主程序中:

  1. 定义正类别和负类别,以及每个类别的样本数量限制。
  2. 使用 create_custom_dataset 创建训练集和测试集。
  3. 使用 DataLoader 加载数据集,设置批次大小并进行数据打乱。
5. DataLoader 的使用
  • DataLoader 用于加载训练数据,并将其按批次处理。我们将自定义数据集传入 DataLoader 并设置批次大小为 64。
  • 在循环中,打印每个批次的大小和标签信息。
6. 输出示例

运行此代码时,您将看到类似以下的输出:

Training dataset size: 5000
Test dataset size: 3500
Batch size: 64, Labels: tensor([1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1])

优化说明

  • 代码中使用了 np.random.permutation 来打乱数据集的顺序,确保数据的随机性。
  • 自定义数据集和图像预处理功能让代码具有灵活性,能够方便地处理不同任务的需求。
  • 使用 DataLoader 来批量加载数据,提升训练效率。

相关文章:

  • Spring Boot 整合 OpenFeign 教程
  • VitePress由 Vite 和 Vue 驱动的静态站点生成器
  • 自然语言处理(5)—— 中文分词
  • 高等数学-第七版-上册 选做记录 习题5-2
  • Linux——线程
  • 构音障碍(Dysarthria)研究全景总结(1996–2024)
  • 在Mac M1/M2芯片上完美安装DeepCTR库:避坑指南与实战验证
  • systemd-networkd 的 /etc/systemd/network/*.network 能不能一个文件配置多块网卡?不能
  • [01-04-02].第20节:PyQt5库初识及实现简易计算器
  • 网络安全基础
  • 文字变央视级语音转换工具
  • OpenRAND可重复的随机数生成库
  • 元宇宙时代下的 Facebook:机遇与挑战
  • IDEA修改默认作者名称
  • Android Compose 约束布局(ConstraintLayout、Modifier.constrainAs)源码深度剖析(十二)
  • #include <hello.h> 与 #include “hello.h“的区别
  • YOLO学习笔记 | YOLO系列算法研究进展及应用综述
  • ant-vue-design 中a-select下拉选择框全局自定义滚动条样式
  • 探秘格式化:数据危机与恢复之道
  • Apache Seatunnel
  • 外交部:各方应为俄乌双方恢复直接对话创造条件
  • 极限拉扯上任巴西,安切洛蒂开启夏窗主帅大挪移?
  • 巫蛊:文化的历史暗流
  • 政策一视同仁引导绿色转型,企业战略回应整齐划一?
  • 10名“鬼火少年”凌晨结队在城区飙车,警方:涉非法改装,正处理
  • 美国长滩港货运量因关税暴跌三成,港口负责人:货架要空了