当前位置: 首页 > news >正文

Dataset和Dataloader

什么是Dataset和Dataloader

  • Dataset指定了数据集包含了什么,可以是自定义数据集,也可以是以及官方数据集
  • Dataloader指定了这个数据集应该以怎样的方式进行加载

定义Dataset

自定义的Dataset格式如下所示

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        # 定义了数据集包含了什么东西
        self.x = []
        self.y = []

    def __len__(self):
        # 返回数据集的总长度
        return len(...)

    def __getitem__(self, idx):
        # 当数据集被读取时,需要返回的数据
        ...
        return self.x[idx], self.y[idx]

案例1:导入两个列表到Dataset

from torch.utils.data import Dataset, DataLoader


class NewDataset(Dataset):
    def __init__(self):
        self.x = [i for i in range(12)]
        self.y = [i * 2 for i in range(12)]

    def __getitem__(self, item):
        return self.x[item], self.y[item]

    def __len__(self):
        return len(self.x)


if __name__ == '__main__':
    newdataset = NewDataset()
    newdataloader = DataLoader(newdataset)
    for x_i, y_i in newdataloader:
        print(x_i, y_i)
    newdataloader = DataLoader(newdataset, batch_size=2)
    for x_i, y_i in newdataloader:
        print(x_i, y_i)
    newdataloader = DataLoader(newdataset, batch_size=4, shuffle=True)
    for x_i, y_i in newdataloader:
        print(x_i, y_i)

案例2:导入Excel数据到Dataset

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self):
        filename = "./anli2/data.xlsx"
        data = pd.read_excel(filename)
        self.x1 = data['x1']
        self.x2 = data['x2']
        self.x3 = data['x3']
        self.x4 = data['x4']
        self.y = data['y']

    def __len__(self):
        return len(self.x1)

    def __getitem__(self, item):
        return self.x1[item], self.x2[item], self.x3[item], self.x4[item], self.y[item]


if __name__ == '__main__':
    mydataset = MyDataset()
    mydataloader = DataLoader(mydataset, shuffle=True, batch_size=4)
    for x1, x2, x3, x4, y in mydataloader:
        print(f"x1={x1},x2={x2},x3={x3},x4={x4},y={y}")

案例3:导入图像数据集

# -*- coding: utf-8 -*-
import os
import cv2 as cv
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np


class MyImageDataset(Dataset):
    def __init__(self):
        image_root = r"anli3/image"
        self.file_path_list = []
        dir_name = []
        self.labels = []

        for root, dirs, files in os.walk(image_root):
            if dirs:
                dir_name = dirs
            for file_i in files:
                file_i_full_path = os.path.join(root, file_i)
                self.file_path_list.append(file_i_full_path)
                label = root.split(os.sep)[-1]
                self.labels.append(label)

    def __len__(self):
        return len(self.file_path_list)

    def __getitem__(self, item):
        img = cv.imread(self.file_path_list[item])
        img = cv.resize(img, dsize=(256, 256))
        # 原先的shape为[1,256,256,3]
        # 要将3调换到1的后面
        img = np.transpose(img, (2, 1, 0))
        img_tensor = torch.from_numpy(img)
        label = self.labels[item]
        return img_tensor, label


if __name__ == '__main__':
    mydataset = MyImageDataset()
    mydataloader = DataLoader(mydataset, batch_size=4, shuffle=True, num_workers=4)
    for x_i, y_i in mydataloader:
        print(x_i.shape, y_i)
for root, dirs, files in os.walk(image_root):

它是 Python 中 os 模块的一部分。os.walk() 递归遍历指定目录及其子目录,返回三个值:根目录、子目录和文件列表

label = root.split(os.sep)[-1]

使用文件路径分隔符(os.sep)将字符串 root 分割成一个列表。os.sep 是一个在不同操作系统中定义的路径分隔符,Windows 中为 \,而在 Unix/Linux 中为 /

相关文章:

  • 解锁云原生后端开发新姿势:腾讯云大模型API实战攻略
  • 微调大模型:LoRA、PEFT、RLHF 简介
  • 二分查找------练习2
  • Numpy 简单学习【学习笔记】
  • 基于CNN的FashionMNIST数据集识别6——ResNet模型
  • Python 异步编程
  • MIT6.5840 lab3A
  • llama源码学习·model.py[7]Transformer类
  • gcc -fPIC 选项
  • 浅谈Qt事件子系统——以可拖动的通用Widget为例子
  • AI 驱动视频处理与智算革新:蓝耘MaaS释放海螺AI视频生产力
  • one-hot标签详解
  • 6.4考研408数据结构图论核心知识点深度解析
  • DHCPv6 Stateless Vs Stateful Vs Stateless Stateful
  • RAG文本分块的魔法与智慧:传统分块与延迟分块,选哪个?
  • 程序代码篇---Pyqt的密码界面
  • Jetpack Compose 选项卡控件实现
  • 数据结构-二叉树
  • 【Linux 维测专栏 2 -- Deadlock detection介绍】
  • NIO ByteBuffer 总结
  • 国家主席习近平同普京总统举行大范围会谈
  • 国防部:正告菲方停止以任何方式冲撞中方核心利益
  • 老铺黄金拟配售募资近27亿港元,用于门店拓展扩建及补充流动资金等
  • 体坛联播|双杀阿森纳,巴黎晋级欧冠决赛对阵国际米兰
  • 是否有中国公民受印巴冲突影响?外交部:建议中国公民避免前往冲突涉及地点
  • 又一日军“慰安妇”制度受害者去世,大陆在世幸存者仅7人