当前位置: 首页 > news >正文

PyTorch数据处理工具箱(utils.data简介)

utils.data简介

utils.data包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义数据集需
要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供数据
的大小(size),后者通过给定索引获取数据和标签。__getitem__一次只能获取一个数
据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。首
先我们来定义一个简单的数据集,然后通过具体使用Dataset及DataLoader,给读者一个直
观的认识。

1)导入需要的模块。

import torch
from torch.utils import data
import numpy as np

2)定义获取数据集的类。

该类继承基类Dataset,自定义一个数据集及对应标签。

class TestDataset(data.Dataset):#继承Dataset
def __init__(self):
self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#一些由2维向量表示的数据集
self.Label=np.asarray([0,1,0,1,2])#这是数据集对应的标签
def __getitem__(self, index):
#把numpy转换为Tensor
txt=torch.from_numpy(self.Data[index])
label=torch.tensor(self.Label[index])
return txt,label
def __len__(self):
return len(self.Data)

3)获取数据集中数据。

Test=TestDataset()
print(Test[2]) #相当于调用__getitem__(2)
print(Test.__len__())
#输出:
#(tensor([2, 1]), tensor(0))
#5

以上数据以tuple返回,每次只返回一个样本。实际上,Dateset只负责数据的抽取,调
用一次__getitem__只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和
并行加速等操作,可选择DataLoader。DataLoader的格式为:

data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate at 0x7f108ee01620>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
)

主要参数说明:

  • dataset:加载的数据集。
  • batch_size:批大小。
  • shuffle:是否将数据打乱。
  • sampler:样本抽样。
  • num_workers:使用多进程加载的进程数,0代表不使用多进程。
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快
    一些。
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多
    出来不足一个batch的数据丢弃。
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
print('i:',i)
Data,Label=traindata
print('data:',Data)
print('Label:',Label)

运行结果:

i: 0
data: tensor([[1, 2],
[3, 4]])
Label: tensor([0, 1])
i: 1
data: tensor([[2, 1],
[3, 4]])
Label: tensor([0, 1])
i: 2
data: tensor([[4, 5]])
Label: tensor([2])

从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,比如对它
进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器。

dataiter=iter(test_loader)
imgs,labels=next(dataiter)

一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,因为不同的目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就很不方便。不过,使用
PyTorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标
签,还提供很多数据预处理、数据增强等转换函数。

http://www.dtcms.com/a/339679.html

相关文章:

  • UE5 PCG 笔记(一)
  • C++ STL(标准模板库)学习
  • 华为鸿蒙系统SSH如何通过私钥连接登录
  • 传统概率信息检索模型:理论基础、演进与局限
  • 短剧小程序系统开发:打造沉浸式短剧观影体验
  • EPM240T100I5N Altera FPGA MAX II CPLD
  • Spring Cache 整合 Redis 实现高效缓存
  • idea如何设置tab为4个空格
  • 复习登录校验流程:会话跟踪技术与请求拦截方案详解
  • SpringBoot-集成POI和EasyExecl
  • 《Light Sci Appl》突破:vdW材料实现亚波长光学涡旋生成,转换效率达46%
  • 前端基础知识操作系统系列 - 01(操作系统的理解?核心概念有哪些)
  • Spring Ai Prompts
  • 佰力博检测与您探讨电晕极化时有时会击穿是什么原因
  • 海洋牧场智能化监控系统升级,保障养殖安全
  • Web3.0 时代的电商系统:区块链如何解决信任与溯源问题?
  • 嵌入式系统学习Day19(数据结构)
  • 用poll改写select
  • 网站频繁遭遇SQL注入、XSS攻击该怎么办?
  • 分布式搜索(Elasticsearch)深入用法
  • git 创用操作
  • java快速接入mcp以及结合mysql动态管理
  • 【SQL优化案例】统计信息缺失
  • 前端使用koa实现调取deepseekapi实现ai聊天
  • RabbitMQ:SpringAMQP Fanout Exchange(扇型交换机)
  • Apache ECharts 6.0.0 版本-探究自定义动态注册机制(二)
  • HTML5视频加密播放的主要优势
  • 本地存储(Local Storage)与Cookie的深度对比
  • RWA在DeFi中的应用
  • 行业分析---领跑汽车2025第二季度财报