当前位置: 首页 > news >正文

Dataset和Dataloader

什么是Dataset和Dataloader

  • Dataset指定了数据集包含了什么,可以是自定义数据集,也可以是以及官方数据集
  • Dataloader指定了这个数据集应该以怎样的方式进行加载

定义Dataset

自定义的Dataset格式如下所示

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        # 定义了数据集包含了什么东西
        self.x = []
        self.y = []

    def __len__(self):
        # 返回数据集的总长度
        return len(...)

    def __getitem__(self, idx):
        # 当数据集被读取时,需要返回的数据
        ...
        return self.x[idx], self.y[idx]

案例1:导入两个列表到Dataset

from torch.utils.data import Dataset, DataLoader


class NewDataset(Dataset):
    def __init__(self):
        self.x = [i for i in range(12)]
        self.y = [i * 2 for i in range(12)]

    def __getitem__(self, item):
        return self.x[item], self.y[item]

    def __len__(self):
        return len(self.x)


if __name__ == '__main__':
    newdataset = NewDataset()
    newdataloader = DataLoader(newdataset)
    for x_i, y_i in newdataloader:
        print(x_i, y_i)
    newdataloader = DataLoader(newdataset, batch_size=2)
    for x_i, y_i in newdataloader:
        print(x_i, y_i)
    newdataloader = DataLoader(newdataset, batch_size=4, shuffle=True)
    for x_i, y_i in newdataloader:
        print(x_i, y_i)

案例2:导入Excel数据到Dataset

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self):
        filename = "./anli2/data.xlsx"
        data = pd.read_excel(filename)
        self.x1 = data['x1']
        self.x2 = data['x2']
        self.x3 = data['x3']
        self.x4 = data['x4']
        self.y = data['y']

    def __len__(self):
        return len(self.x1)

    def __getitem__(self, item):
        return self.x1[item], self.x2[item], self.x3[item], self.x4[item], self.y[item]


if __name__ == '__main__':
    mydataset = MyDataset()
    mydataloader = DataLoader(mydataset, shuffle=True, batch_size=4)
    for x1, x2, x3, x4, y in mydataloader:
        print(f"x1={x1},x2={x2},x3={x3},x4={x4},y={y}")

案例3:导入图像数据集

# -*- coding: utf-8 -*-
import os
import cv2 as cv
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np


class MyImageDataset(Dataset):
    def __init__(self):
        image_root = r"anli3/image"
        self.file_path_list = []
        dir_name = []
        self.labels = []

        for root, dirs, files in os.walk(image_root):
            if dirs:
                dir_name = dirs
            for file_i in files:
                file_i_full_path = os.path.join(root, file_i)
                self.file_path_list.append(file_i_full_path)
                label = root.split(os.sep)[-1]
                self.labels.append(label)

    def __len__(self):
        return len(self.file_path_list)

    def __getitem__(self, item):
        img = cv.imread(self.file_path_list[item])
        img = cv.resize(img, dsize=(256, 256))
        # 原先的shape为[1,256,256,3]
        # 要将3调换到1的后面
        img = np.transpose(img, (2, 1, 0))
        img_tensor = torch.from_numpy(img)
        label = self.labels[item]
        return img_tensor, label


if __name__ == '__main__':
    mydataset = MyImageDataset()
    mydataloader = DataLoader(mydataset, batch_size=4, shuffle=True, num_workers=4)
    for x_i, y_i in mydataloader:
        print(x_i.shape, y_i)
for root, dirs, files in os.walk(image_root):

它是 Python 中 os 模块的一部分。os.walk() 递归遍历指定目录及其子目录,返回三个值:根目录、子目录和文件列表

label = root.split(os.sep)[-1]

使用文件路径分隔符(os.sep)将字符串 root 分割成一个列表。os.sep 是一个在不同操作系统中定义的路径分隔符,Windows 中为 \,而在 Unix/Linux 中为 /

http://www.dtcms.com/a/85178.html

相关文章:

  • 解锁云原生后端开发新姿势:腾讯云大模型API实战攻略
  • 微调大模型:LoRA、PEFT、RLHF 简介
  • 二分查找------练习2
  • Numpy 简单学习【学习笔记】
  • 基于CNN的FashionMNIST数据集识别6——ResNet模型
  • Python 异步编程
  • MIT6.5840 lab3A
  • llama源码学习·model.py[7]Transformer类
  • gcc -fPIC 选项
  • 浅谈Qt事件子系统——以可拖动的通用Widget为例子
  • AI 驱动视频处理与智算革新:蓝耘MaaS释放海螺AI视频生产力
  • one-hot标签详解
  • 6.4考研408数据结构图论核心知识点深度解析
  • DHCPv6 Stateless Vs Stateful Vs Stateless Stateful
  • RAG文本分块的魔法与智慧:传统分块与延迟分块,选哪个?
  • 程序代码篇---Pyqt的密码界面
  • Jetpack Compose 选项卡控件实现
  • 数据结构-二叉树
  • 【Linux 维测专栏 2 -- Deadlock detection介绍】
  • NIO ByteBuffer 总结
  • WPF控件DataGrid介绍
  • Ubuntu常用命令大全 | 零基础快速上手指南
  • Python环境安装
  • 【C++】内存管理
  • Github 2025-03-23 php开源项目日报Top10
  • MySQL中的锁(全局锁、表锁和行锁)
  • Java19虚拟线程原理详细透析以及企业级使用案例。
  • SpringMVC 的面试题
  • Python Cookbook-4.11 在无须过多援引的情况下创建字典
  • CICDDevOps概述