Pytorch-03数据的Transform
Transforms
原始的数据形式可能并不符合模型算法所要求的输入形式。例如一个图片刚读入内存的时候可能还是numpy形式,而模型输入需要是tensor形式等等。亦或者模型要求的标签是one-hot,而现在的标签是整数等等的情况。
为了规范、统一的解决这个问题,pytorch定义了transform
来在Dataset的__getitem__
阶段,对数据进行处理,然后再交给Dataloader,再交给模型以供训练或推理。
在TorchVision中,所有的datasets
都有两个可以指定的transform参数:
transform
: 定义要怎么处理初始的features(数据样本)target_transform
:定义要怎么处理初始标签
torchvision.transforms模块提供了很多开箱即用的,常用的转换方法
对于FashionMNIST数据集,图片是PIL图片格式, 标签是整数形式,为了能进行分类训练,我们需要把图片转换成归一化之后的tensors,并且把整数标签转换为one-hot编码之后的tensor。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y; torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
ToTensor()
这个方法会把一个PIL图片或者ndarray转换成FloatTensor
, 并且把图片的像素值归一化到[0, 1]之间。
归一化对训练又很多好处,如加速训练,避免梯度爆炸或者梯度消失,让训练更加稳定。
Lambda Transforms
你可以用torchvision.transforms.Lambda
将任何简单的函数或 lambda 表达式作为转换器来使用。这里是定义了一个标签转换为one-hot编码tensor的匿名函数。
target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
lambda y
表明是一个匿名函数,接受一个参数y,也就是标签的整数值,然后利用torch.zeros(10, dtype=torch.float)
创建一个全0的一维张量,最后使用scatter_(dim=0. index=torch.tensor(y), value=1)
就地操作把自己index为y的元素赋值为1,这样就实现了one-hot编码。