当前位置: 首页 > news >正文

pytorch 数据处理

torch工具类Dataset和DataLoader

对于NN模型训练来说,需要将数据转换成torch识别的数据类型,才能喂给模型。pytorch中,通常使用Dataset和DataLoader这两个工具类来构建数据管道。

  • Dataset定义了数据集的内容,类似一个列表的数据结构,有确定的长度,能够用索引获取数据集中的元素。
  • DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
batch_size: how many samples per batch to load
shuffle: set to ``True`` to have the data reshuffled at every epoch (default: ``False``).
drop_last: set to ``True`` to drop the last incomplete batch

自定义DataSet都需要集成DataSet父类,复写 __init__,__getitem__和__len__方法。

from numpy.ma.core import shape
from torch.utils.data import Dataset
import torch


class MyDataset(Dataset):
    def __init__(self, dataList, labelList):
        self.dataList = dataList
        self.labelList = labelList


    def __getitem__(self, idx):
        return self.dataList[idx], self.labelList[idx]


    def __len__(self):
        return len(self.labelList)


dataList, labelList = torch.randn(1000,3),torch.randint(low=0, high=2, size=(1000,)).float()
dataset_test = MyDataset(dataList, labelList)

用DataLoader读取Dataset的数据

dl = DataLoader(dataset_test, batch_size=4, drop_last=True)
data, label = next(iter(dl))
print("data=", data)
print("label=", label)

Dataset的创建方法

Dataset创建数据集常用的方法有:

  • 继承 torch.utils.data.Dataset 创建自定义数据集,如上;
  • 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集;
  • 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。

使用torchvision提供的数据集

数据集地址:# https://pytorch.org/vision/stable/datasets.html#built-in-datasets

from torch.utils.data import TensorDataset,DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

# 下载训练集
train_set = torchvision.datasets.CIFAR10(root="./trainset", train=True, download=True)
# 下载测试集
test_set = torchvision.datasets.CIFAR10(root="./trainset", train=False, download=True)
# 查看数据类型
print(test_set[0])
print(test_set.classes)


# 做数据转换,从PIL>tensor
dataset_compose = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 利用tensorboard展示浏览图像
writer = SummaryWriter("cifar10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", dataset_compose(img), i)

writer.close()

DataLoader的使用

test_set = torchvision.datasets.CIFAR10(root="./trainset", train=False, download=True, transform=transforms.ToTensor())
test_loader = DataLoader(dataset=test_set, batch_size=4, shuffle=True, drop_last=False)
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)

通过dataloader可一次性从数据集中取多少数据,并且可以设定采样情况。

相关文章:

  • Spring Boot 动态定时任务:实现与应用详解
  • XILINX AXI总线
  • windows AWTK开发环境搭建
  • 解决 启动模拟器出现 未开启Hyper-V 的问题
  • 退出登录时如何使JWT令牌失效?
  • 学习日志8.21--防火墙NAT
  • led台灯对眼睛好不好?护眼台灯怎么选对眼睛好?收下这份攻略
  • C#过 SemaphoreSlim 实现高效的数据库并发控制和资源管理(多线程)
  • React 入门第八天:性能优化与开发者工具的使用
  • python 实现一个简单的网页爬虫程序
  • Python编程的特点
  • 一文教你编写有效提示词,了解常用提示词工具—Prompt Engineering for Gen AI
  • 解决MAC电脑SVN Android studio不能提交.so文件相关
  • python创建虚拟环境并在pycharm引用
  • 网络安全售前入门05安全服务——渗透测试服务方案
  • 【软件文档】项目总结报告编制模板(Word原件参考)
  • hdfs的慢盘检测
  • Nacos2.4.1安装
  • Stable Diffusion详解
  • Javaweb学习之Vue数据绑定(五)
  • 王毅谈金砖国家开展斡旋调解的经验和独特优势
  • 鲁迅先生儿媳、周海婴先生夫人马新云女士逝世,享年94岁
  • “80后”蒋美华任辽宁阜新市副市长
  • 牛市早报|今年第二批810亿元超长期特别国债资金下达,支持消费品以旧换新
  • “75后”袁达已任国家发改委秘书长
  • 观察|英国航母再次部署印太,“高桅行动”也是“高危行动”