当前位置: 首页 > news >正文

深度学习(4):数据加载器

一、Dataset:数据集类

1.数据集类需要继承Dataset类

2.实现__init__方法,数据初始化

3.实现__len__方法,返回数据集的长度

4.实现__getitem__方法,根据索引下标获取数据

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsclass MyDataset(Dataset):def __init__(self,data,labels):assert len(data) == len(labels)self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self,index):sample = self.data[index]label = self.labels[index]return sample,label

二、DataLoader:数据加载器

返回一个迭代器

参数:

dataset:要加载的数据集

batch_size:每批次读取的样本数量

shuffle:是否打乱顺序,True-打乱,False-不打乱

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsx = torch.randn(1000, 20)
y = torch.randn(1000, 10)dataset = MyDataset(x, y)
print(len(dataset))#1000
print(dataset[0])
"""
(tensor([ 0.1911, -0.0872,  0.4112, -0.3616, -2.4566, -0.5119,  0.1298,  1.0090,-0.6610, -1.3058,  0.1351, -1.6622,  0.8579, -0.5143,  0.6540, -0.0464,0.4354, -0.1966, -0.1209,  0.2876]), tensor([ 1.8922,  1.4897, -1.4169, -1.2283, -0.9311, -0.7850,  0.9580,  0.3025,0.3257, -0.3441]))
"""dataloader = DataLoader(dataset=dataset,batch_size=100,shuffle=True
)for x, y in dataloader:
print(x.shape, y.shape)#torch.Size([100, 20]) torch.Size([100, 10])break

三、TensorDataset: torch提供的dataset类

如果对数据没有特殊处理的情况下,可以考虑使用TensorDataset

如果需要对数据进行特殊处理,可以考虑自定义Dataset数据集

    x = torch.randn(1000, 20)y = torch.randn(1000, 10)dataset = TensorDataset(x, y)dataloader = DataLoader(dataset=dataset,batch_size=100,shuffle=True)for x, y in dataloader:print(x.shape, y.shape)#torch.Size([100, 20]) torch.Size([100, 10])break

四、自定义图片加载器

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsfilepath = './datasets/animals'transform = transforms.Compose([transforms.Resize(size=(224, 224)),transforms.ToTensor()
])dataset = datasets.ImageFolder(filepath, transform=transform)
dataloader = DataLoader(dataset=dataset,batch_size=20,shuffle=True
)for x, y in dataloader:print(x, y)break

五、加载MNIST数据集

# MNIST数据集:黑底白字的手写数字,图片分辨率:28*28
# 分训练数据集(60000)和测试数据集(10000)
def test05():transform = transforms.Compose([transforms.ToTensor()])# train: 是否为训练数据集# root:保存数据集的路径# transform:图片转换器train_dataset = datasets.MNIST(root='./datasets',train=True,download=True,transform=transform)dataloader = DataLoader(dataset=train_dataset,batch_size=20,shuffle=True)# 按批次遍历,每批次读取batch_size个数据for x, y in dataloader:print(x, y)break
"""
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],...,[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([9, 2, 2, 0, 9, 1, 3, 7, 2, 5, 1, 8, 8, 8, 6, 2, 7, 6, 4, 2])
"""

http://www.dtcms.com/a/329776.html

相关文章:

  • Redis7学习——Redis的初认识
  • 51c自动驾驶~合集14
  • Docker:快速部署 Temporal 工作流引擎的技术指南
  • 3DM游戏运行库合集离线安装包下载, msvcp140.dll丢失等问题修复
  • 迅雷链接在线解密解析工具系统源码/本地化API/开源
  • 前缀函数的运用
  • Harmony OS 开发入门 第三章
  • Python Day29 CSS样式
  • Protobuf学习(1)—— 初识与安装
  • 代理解决跨域
  • SparseArray ArrayMap
  • Activity和Fragment生命周期
  • Spring进阶(八股篇)
  • 栈和队列详解
  • LeetCode刷题记录----437.路径总和Ⅲ(medium)
  • 学习:JS进阶[10]内置构造函数
  • HunyuanVideo-Avatar:为多个角色制作高保真音频驱动的人体动画
  • C++哈希进阶-位图
  • 计算机网络技术-知识篇(Day.1)
  • java14学习笔记-打包工具 (Incubator)
  • MoonBit Perals Vol.05: 函数式里的依赖注入:Reader Monad
  • JPrint免费的Web静默打印控件:PDF打印中文乱码异常解决方案
  • 什么是JSP和Servlet以及二者的关系
  • window显示驱动开发—多平面覆盖 VidPN 呈现
  • MVCC底层实现原理
  • Flask入门:从零搭建Web服务器
  • 雅思大作文笔记
  • iOS 签名证书在版本迭代和iOS上架中的全流程应用
  • Docker 在 Linux 中的额外资源占用分析
  • 智汇河套,量子“风暴”:量子科技未来产业发展论坛深度研讨加速产业成果转化