当前位置: 首页 > news >正文

timm教程翻译:(六)Data

https://timm.fast.ai/dataset

6.1 Dataset

timm 库中有三个主要的 Dataset 类:

  • ImageDataset
  • IterableImageDataset
  • AugMixDataset

在本文档中,我们将分别介绍它们,并探讨这些 Dataset 类的各种用例。

6.1.1 ImageDataset

class ImageDataset(root: str, parser: Union[ParserImageInTar, ParserImageFolder, str] = None,\class_map: Dict[str, str] = '', load_bytes: bool = False, \transform: List = None) -> Tuple[Any, Any]:

ImageDataset 可用于创建训练验证数据集,其功能与 torchvision.datasets.ImageFolder 非常相似,并带有一些不错的附加组件。

6.1.1.1 Parser

解析器使用 create_parser工厂方法自动设置。解析器root目录中查找所有图像和目标,其中root文件夹的结构如下:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.pngroot/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

解析器设置一个 class_to_idx 字典,将类映射到整数,如下所示:

{'dog': 0, 'cat': 1, ..}

还有一个名为samples的属性,它是一个元组列表,如下所示:

[('root/dog/xxx.png', 0), ('root/dog/xxy.png', 0), ..., \
('root/cat/123.png', 1), ('root/cat/nsdf3.png', 1), ...]

这个解析器对象是可下标的,执行类似 parser[index] 的操作时,它会返回 self.samples 中该索引处的样本。因此,执行类似 parser[0] 的操作将返回 ('root/dog/xxx.png', 0)

getitem(index: int) → Tuple[Any, Any]

一旦设置了解析器,ImageDataset 就会根据索引从该解析器中获取图像目标。

img, target = self.parser[index]

然后,它会将图像读取为 PIL.Image 并转换为 RGB 格式,或者根据 load_bytes 参数将图像读取为字节格式。

最后,它会转换图像并返回目标。如果目标为 None,则返回一个虚拟目标 torch.tensor(-1)。

6.1.1.2 用法

此 ImageDataset 也可以用作 torchvision.datasets.ImageFolder 的替代品。假设我们有一个 imagenette2-320 数据集,其结构如下:

imagenette2-320
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val├── n01440764├── n02102040├── n02979186├── n03000684├── n03028079├── n03394916├── n03417042├── n03425413├── n03445777└── n03888257

每个子文件夹包含一组属于该类的 .JPEG 文件。

# run only once
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
gunzip imagenette2-320.tgz
tar -xvf imagenette2-320.tar

然后,可以像这样创建一个 ImageDataset:

from timm.data.dataset import ImageDatasetdataset = ImageDataset('./imagenette2-320')
dataset[0]
(<PIL.Image.Image image mode=RGB size=426x320 at 0x7FF7F4880460>, 0)

我们还可以看到dataset.parser是ParserImageFolder的一个实例:

dataset.parser<timm.data.parsers.parser_image_folder.ParserImageFolder at 0x7ff7f4880d90>

最后,让我们看一下解析器中的 class_to_idx 字典映射:

dataset.parser.class_to_idx{'n01440764': 0,'n02102040': 1,'n02979186': 2,'n03000684': 3,'n03028079': 4,'n03394916': 5,'n03417042': 6,'n03425413': 7,'n03445777': 8,'n03888257': 9}

And, also, the first five samples like so:
此外,前五个样本如下:

dataset.parser.samples[:5][('./imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG', 0),('./imagenette2-320/train/n01440764/ILSVRC2012_val_00002138.JPEG', 0),('./imagenette2-320/train/n01440764/ILSVRC2012_val_00003014.JPEG', 0),('./imagenette2-320/train/n01440764/ILSVRC2012_val_00006697.JPEG', 0),('./imagenette2-320/train/n01440764/ILSVRC2012_val_00007197.JPEG', 0)]

6.1.2 IterableImageDataset

timm 还提供了与 PyTorch 的 IterableDataset 类似的 IterableImageDataset,但有一个关键区别 - IterableImageDataset 在生成图像和目标之前将transforms应用于图像。

Such form of datasets are particularly useful when data come from a stream or when the length of the data is unknown.
当数据来自流或数据长度未知时,这种形式的数据集特别有用。

timm 将transforms延迟应用于图像,并在目标为 None 时将目标设置为虚拟目标 torch.tensor(-1, dtype=torch.long)。

与上面的 ImageDataset 类似,IterableImageDataset 首先创建一个解析器,该解析器根据根目录获取一个样本元组。

如前所述,解析器返回一张图片,而目标是图片所在的对应文件夹。

注意:IterableImageDataset 没有定义 _getitem_ 方法,因此它不可下标。如果数据集是 IterableImageDataset 的一个实例,执行类似 dataset[0] 的操作将返回错误。

iter

The iter method inside IterableImageDataset first gets an image and a target from self.parser and then lazily applies the transforms to the image. Also, sets the target as a dummy value before both are returned.

IterableImageDataset 中的 _iter_ 方法首先从 self.parser 获取图像和目标,然后以惰性方式将变换应用于图像。此外,在返回两者之前,将目标设置为虚拟值。

6.1.2.1 Usage
from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform root = '../../imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0], next(iter(iterable_dataset))
((<_io.BufferedReader name='../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG'>,0),(<_io.BufferedReader name='../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG'>,0))

The iterable_dataset is not Subscriptable.

iterable_dataset[0]
> > 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-14-9085b17eda0c> in <module>
----> 1 iterable_dataset[0]~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)30 31     def __getitem__(self, index) -> T_co:---> 32         raise NotImplementedError     33 34     def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':NotImplementedError:

6.1.3 AugmixDataset

class AugmixDataset(dataset: ImageDataset, num_splits: int = 2):

AugmixDataset 接受 ImageDataset 并将其转换为 Augmix Dataset。

什么是 Augmix 数据集以及我们什么时候需要使用它?

Let’s answer that with the help of the Augmix paper.

Augmix

如上图所示,最终的损失输出实际上是 X o r i g X_{orig} Xorig X a u g m i x 1 X_{augmix1} Xaugmix1 X a u g m i x 2 X_{augmix2} Xaugmix2 上标签与模型预测之间的分类损失与 λ 乘以 Jensen-Shannon 损失之和。

因此,在这种情况下,我们需要该批次的三个版本——原始版本、augmix1 和 augmix2。那么我们如何实现呢?当然是使用 AugmixDataset!
注意:augmix1 和 augmix2 是原始批次的增强版本,其中的增强操作是从操作列表中随机选择的。

_getitem_(index: int) -> Tuple[Any, Any]

首先,我们从 self.dataset(传入 AugmixDataset 构造函数的数据集)中获取 X 及其对应的标签 y。接下来,我们对图像 X 进行归一化,并将其添加到名为 x_list 的变量中。

接下来,基于 num_splits 参数(默认值为 0),对 X 进行增强操作,对增强后的输出进行归一化,并将其添加到 x_list 中。
注意:如果 num_splits=2,则 x_list 包含两项:原始数据 + 增强后数据。如果 num_splits=3,则 x_list 包含三项:原始数据 + 增强后数据 1 + 增强后数据 2。依此类推。

6.1.3.2 Usage
from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loaderdataset = ImageDataset('../../imagenette2-320/')
dataset = AugMixDataset(dataset, num_splits=2)loader_train = create_loader(dataset, input_size=(3, 224, 224), batch_size=8, is_training=True, scale=[0.08, 1.], ratio=[0.75, 1.33], num_aug_splits=2
)
# Requires GPU to worknext(iter(loader_train))[0].shape>> torch.Size([16, 3, 224, 224])

注意:现在你可能会问——我们传入了 batch_size=8,但 loader_train 返回的批次大小却是 16?为什么会这样?
因为我们传入了 num_aug_splits=2。在这种情况下,loader_train 包含前 8 张原始图像和接下来 8 张代表 augmix1 的图像。

如果我们传入了 num_aug_splits=3,那么实际的 batch_size 应该是 24,其中前 8 张图像是原始图像,接下来 8 张代表 augmix1,最后 8 张代表 augmix2。

6.2 DataLoaders

timm DataLoader 与 torch.utils.data.DataLoader 略有不同,速度更快。让我们在这里探索一下。

在 timm 中创建数据加载器最简单的方法是调用 timm.data.loader 中的 create_loader 函数。该函数需要一个数据集对象、一个 input_size 参数以及一个 batch_size 参数。其他所有参数都已预设,方便我们操作。让我们看一个使用 timm 创建数据加载器的快速示例。

6.2.1 用法示例

!tree ../../imagenette2-320/ -d
../../imagenette2-320/
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val├── n01440764├── n02102040├── n02979186├── n03000684├── n03028079├── n03394916├── n03417042├── n03425413├── n03445777└── n0388825722 directories
from timm.data.dataset import ImageDatasetdataset = ImageDataset('../../imagenette2-320/')
dataset[0]
(<PIL.Image.Image image mode=RGB size=426x320 at 0x7F8379C26190>, 0)

太棒了,我们已经创建了数据集。timm 中的 ImageDataset 与 torchvision.datasets.ImageFolder 非常相似,但增加了一些不错的功能。让我们可视化一下数据集中的第一张图片。不出所料,这是一张丁鲷的图片!😉
注意:默认情况下,上面创建的数据集用于训练文件夹,因此我们可以将其称为训练数据集。

from matplotlib import pyplot as plt# visualize image
plt.imshow(dataset[0][0])
<matplotlib.image.AxesImage at 0x7f83702a7bd0>

在这里插入图片描述

Let’s now create our DataLoader.

from timm.data.loader import create_loadertry:# only works if gpu present on machinetrain_loader = create_loader(dataset, (3, 224, 224), 4)
except:train_loader = create_loader(dataset, (3, 224, 224), 4, use_prefetcher=False)

你可能会问,为什么上面要用 try-except 块?第一个 train_loader 和第二个 train_loader 有什么区别? use_prefetcher 参数是什么?它有什么作用?

6.2.2 Prefetch loader

timm 内部有一个名为 PrefetchLoader 的类。默认情况下,我们使用这个预取加载器prefetch loader来创建数据加载器。但是,它只适用于支持 GPU 的机器。由于我可以使用 GPU,所以 train_loader 对我来说就是这个 PrefetchLoader 类的一个实例。

train_loader
<timm.data.loader.PrefetchLoader at 0x7f836fd8c9d0>

注意:如果您在一台只有 CPU 的机器上运行此笔记本,train_loader 将是 torch.utils.dataloader 的一个实例。

现在让我们看看这个 PrefetchLoader 做了什么?所有有趣的部分都发生在这个类的_iter_ 方法中。

def __iter__(self):stream = torch.cuda.Stream()first = Truefor next_input, next_target in self.loader:with torch.cuda.stream(stream):next_input = next_input.cuda(non_blocking=True)next_target = next_target.cuda(non_blocking=True)if self.fp16:next_input = next_input.half().sub_(self.mean).div_(self.std)else:next_input = next_input.float().sub_(self.mean).div_(self.std)if self.random_erasing is not None:next_input = self.random_erasing(next_input)if not first:yield input, targetelse:first = Falsetorch.cuda.current_stream().wait_stream(stream)input = next_inputtarget = next_targetyield input, target

让我们试着理解一下到底发生了什么?我们只需要了解 cuda.streams 就能理解 PrefetchLoader 中的 _iter_ 方法。

PyTorch 上的文档写道:

CUDA 流是属于特定设备的线性执行序列。通常无需显式创建:默认情况下,每个设备都使用自己的“默认”流。每个流中的操作都按照创建顺序序列化,但不同流中的操作可以按任意相对顺序并发执行,除非使用显式同步函数\
(例如 synchronize() 或 wait_stream())。当“当前流”是默认流时,PyTorch 会在数据移动时自动执行必要的同步。但是,当使用非默认流时,用户有责任确保正确的同步。

简而言之,每个 CUDA 设备都可以拥有自己的“流”,即按顺序运行的命令序列。但这并不意味着所有流(如果存在多个 CUDA 设备)都彼此同步。可能的情况是,当命令 1 在第一个 CUDA 设备的“流”上运行时,命令 3 可能在第二个 CUDA 设备的“流”上运行。

但这有什么关系呢?“流”可以用来加快数据加载器的速度吗?

当然!这就是重点!基本上,用 Ross 的话来说,PrefetchLoader 背后的关键思想是:

“使用异步 CUDA 传输进行预取prefetch有助于稍微降低批量传输到 GPU 时卡顿的可能性,因为它(希望)可以更快地启动它,并使其更灵活地在自己的 CUDA 流中与其他操作同时运行。”

基本上,我们在设备自己的“流”中执行“迁移到 CUDA”步骤,而不是在默认流中。这意味着此步骤可以异步执行,而其他一些操作可能正在 CPU 或默认“流”中进行。这有助于加快速度,因为现在 CUDA 上的数据可以更快地传递到模型中。

这就是_iter_ 方法内部的工作。

对于第一个批次,我们像在 torch.utils.data.DataLoader 中一样迭代加载器,并返回输入目标。

但是,对于之后的每个批次,我们首先使用 torch.cuda.stream(stream): 为 CUDA 设备实例化一个“流”,接下来,我们以异步方式在该设备自己的“流”内执行 CUDA 传输,并返回 next_input 和 next_target。

因此,每次迭代数据加载器时,我们实际上都会返回一个预取的输入和目标,因此得名 PrefetchLoader。

http://www.dtcms.com/a/511483.html

相关文章:

  • VSCode + AI Agent实现直接编译调试:告别Visual Studio的原理与实践
  • 【设计模式】建造者模式(Builder)
  • DeepSeek-OCR:把长文本“挤进图片”的新思路
  • 计算机做网站开题报告网页的六个基本元素
  • AI服务器工作之整机部件(CPU+内存)
  • 【EE初阶 - 网络原理】网络层 + 数据链路层 + DNS
  • 关于二级网站建设西安网站制作一般多少钱
  • 【机器学习06】神经网络的实现、训练与向量化
  • [人工智能-大模型-25]:大模型应用层技术栈 - 大模型应用层的四大开发模式(如何利用大语言模型?)
  • YOLO目标检测:一种用于无人机的新型轻量级目标检测网络
  • 第六部分:VTK进阶(第166章 标量-向量-张量场管理)
  • A Survey of Camouflaged Object Detection and Beyond论文阅读笔记
  • 基于 hexo + github 的个人博客系统搭建
  • 成都私人做网站建设自由做图网站
  • 哈尔滨做网站找哪家好网站的在线支付怎么做
  • 使用pem和key文件给springboot开启https服务
  • XSS攻击防护完整指南
  • 基于Spring Boot的高校实习实践管理系统(源码+论文+部署+安装)
  • 第11篇:源码解析:Jackson核心流程与设计模式
  • 数据库原理实验报告:在ider里搭建mysql数据库
  • 面试(四)——Java 八大包装类、String 、日期类及文件操作核心类 File全解析
  • 【无标题】大模型-7种大模型微调方法 上
  • 信用网站系统建设方案阿里云服务器建设网站选择那个镜像
  • 大型的PC网站适合vue做吗网页制作工具通常在什么上建立热点
  • C++字符串操作与递增递减运算符详解
  • Python 的基本数据类型与它们之间的关系
  • All in One Runtimes下载和安装图解(附安装包,适合新手)
  • Python多patch装饰器使用指南
  • Prometheus监控系统
  • 【Java-集合】Set接口