Pytorch中的Transforms学习
1、Transforms介绍
在 PyTorch 的计算机视觉库 torchvision 中,transforms 模块是一个用于图像预处理和数据增强的核心工具。它提供了一系列预定义的函数和类,用于将原始图像转换为适合深度学习模型输入的张量格式,同时支持数据增强以提高模型的泛化能力。
1.1 主要功能
torchvision.transforms 的核心目标是:
- 图像预处理:将图像转换为张量(Tensor)并标准化。
- 数据增强:通过随机变换生成多样化的训练数据,防止过拟合。
- 灵活组合:通过 Compose 将多个变换操作串联成流水线。
1.2 常用操作分类
- 基础图像变换
- Resize(size)
调整图像尺寸(支持整数或 (H, W) 元组)。
transforms.Resize(256) # 将短边缩放到256,长边按比例调整
transforms.Resize((224, 224)) # 强制缩放到224x224
- CenterCrop(size) / RandomCrop(size)
中心裁剪或随机裁剪到指定尺寸。
transforms.RandomCrop(224) # 随机裁剪224x224区域
- RandomHorizontalFlip(p=0.5)
按概率 p 随机水平翻转图像。
transforms.RandomHorizontalFlip(p=0.5) # 50%概率翻转
- RandomRotation(degrees)
随机旋转图像(角度范围 [-degrees, degrees])。
transforms.RandomRotation(30) # 随机旋转-30°到30°
- 张量转换与标准化
- ToTensor()
将 PIL.Image 或 numpy.ndarray 转换为 torch.Tensor,并自动归一化到 [0, 1]。
transforms.ToTensor() # 输入形状 (H, W, C) → 输出 (C, H, W)
- Normalize(mean, std)
对张量进行标准化(按通道计算):
o u t p u t = ( i n p u t − m e a n ) s t d output= \frac{(input−mean)}{std} output=std(input−mean)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet统计值
- 颜色与对比度变换
- ColorJitter(brightness, contrast, saturation, hue)
随机调整亮度、对比度、饱和度和色调。
transforms.ColorJitter(brightness=0.2, contrast=0.2) # 亮度、对比度随机调整20%
- Grayscale(num_output_channels=1)
将图像转为灰度图。
transforms.Grayscale() # 输出单通道灰度图
- RandomGrayscale(p=0.1)
按概率 p 将图像转为灰度图。
transforms.RandomGrayscale(p=0.1) # 10%概率转灰度
- 组合变换(Compose)
通过 Compose 将多个变换按顺序组合成一个流水线:
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
- 自定义变换
对于复杂需求,可以使用 Lambda 或自定义函数:
# 使用Lambda定义简单变换
transforms.Lambda(lambda x: x.rotate(90)) # 旋转90度
# 自定义函数(需返回张量或PIL图像)
def custom_transform(img):
# 处理图像逻辑
return img
transform = transforms.Compose([
transforms.Resize(256),
custom_transform,
transforms.ToTensor(),
])
- 功能性变换(Functional Transforms)
torchvision.transforms.functional 提供了细粒度的函数式接口,适合自定义复杂逻辑(需手动处理随机性):
from torchvision.transforms import functional as F
def random_rotation(img):
angle = torch.randint(-30, 30, (1,)).item()
return F.rotate(img, angle)
transform = transforms.Compose([
transforms.Resize(256),
transforms.Lambda(random_rotation),
transforms.ToTensor(),
])
1.3 典型应用场景
- 训练时的数据增强:
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(),
transforms.ToTensor(),
transforms.Normalize(...),
])
- 测试/推理时的预处理:
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(...),
])
- 与数据集结合:
from torchvision.datasets import CIFAR10
dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
1.4 注意事项
- 顺序敏感:某些操作必须按特定顺序执行(如 ToTensor 应在 Normalize 之前)。
- 数据类型:ToTensor 会将 uint8 图像转换为 float32 张量。
- 标准化参数:需根据数据集统计值调整 mean 和 std(例如 ImageNet 的默认值)。
1.5 总结
torchvision.transforms 是 PyTorch 中处理图像数据的核心模块,通过灵活的变换组合,可以实现:
- 规范化输入:适配模型输入格式。
- 增强多样性:提升模型鲁棒性。
- 高效流水线:简化数据预处理流程。
对于具体任务(如医学图像、卫星图像等),可以结合自定义变换扩展功能。
2、Transforms操作实例
2.1 实例一:将图片转换为tensor格式
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
# 创建一个 SummaryWriter 对象
writer = SummaryWriter("runs")
image_path = "hymenoptera_data/train/ants/0013035.jpg"
img_PIL = Image.open(image_path)
tensor_trans = transforms.ToTensor()
img_tensor = tensor_trans(img_PIL)
print(type(img_tensor))
print(img_tensor.shape)
writer.add_image("tensor_img", img_tensor)
writer.close()
程序使用了Tensorboard,和PIL中的Image模块,如果不熟悉,可以复习一下:
Pytorch中Tensorboard的学习
Pytorch中的数据加载
tensor_trans = transforms.ToTensor()
img_tensor = tensor_trans(img_PIL)
使用transforms中的ToTensor将PIL图像数据转为tensor类型的图像数据。
运行程序,打印信息:
<class 'torch.Tensor'>
torch.Size([3, 512, 768])
在终端执行命令:
Tensorboard --logdir=E:\my_pycharm_projects\project1\runs
返回结果:
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
打开网址:
2.2 实例二:
1、准备数据集
首先,将数据集放到项目所在文件夹中。数据集示例如下:
同样包含训练集(train)和验证集(val)。
不过,不同的是,例如在train文件夹下,图片数据和图片标签分别存放在不同的文件夹中(如ants_image和ants_label分别存放蚂蚁的图片和对应图片的标签)。
ants_image中为.ipg的图片文件,ants_label中为.txt文本文件,他们的名字是一一对应的。
文本文件中是该图片的标签。
2、数据加载
这里要用到torch.utils.data中的Dataset类,PIL中的Image模块以及os模块。可以复习一下之前的内容:
Pytorch中的数据加载
- 导入所需的类或模块
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms
Dataset: PyTorch的基础数据集类,自定义数据集需要继承它
DataLoader: 用于批量加载数据
Image: PIL库的图像处理模块
os: 用于文件路径操作
transforms: PyTorch的图像预处理工具
- 编写数据类
编写一个表示数据集的类,继承自Dataset类,并重写方法__init__()和方法__getitem__()和方法__len__():
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms
class MyData(Dataset):
def __init__(self, root_dir, image_dir, label_dir, transform=None):
self.root_dir = root_dir
self.image_dir = image_dir
self.label_dir = label_dir
self.label_path = os.path.join(self.root_dir, self.label_dir)
self.image_path = os.path.join(self.root_dir, self.image_dir)
self.image_list = os.listdir(self.image_path)
self.label_list = os.listdir(self.label_path)
self.transform = transform
# 因为label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的
self.image_list.sort()
self.label_list.sort()
def __getitem__(self, idx):
img_name = self.image_list[idx]
label_name = self.label_list[idx]
img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
img = Image.open(img_item_path)
with open(label_item_path, 'r') as f:
label = f.readline()
if self.transform:
img = transform(img)
return img, label
def __len__(self):
assert len(self.image_list) == len(self.label_list)
return len(self.image_list)
方法__getitem__()中有一个参数:transform,用于对加载的图像数据进行变换。
- 加载数据集
transform = transforms.Compose([
transforms.Resize(400), # 调整图像大小
transforms.ToTensor() # 转换为张量并归一化
])
分别读取ants和bees图像数据以及它们对应的标签,并对图像数据进行transform变换:
- transforms.Compose([])
这是一个容器,将多个图像变换操作按顺序组合成一个可调用的对象。输入图像会依次通过这些变换。 - transforms.Resize(400)
功能:将图像的短边调整为400像素(保持长宽比)。
如果参数是(H, W)(如(400, 300)),则强制调整为指定尺寸,可能破坏长宽比。 - transforms.ToTensor()
功能:将PIL图像或NumPy数组转换为PyTorch张量(torch.Tensor)。
自动将像素值从[0, 255]缩放到[0.0, 1.0](若输入是uint8类型)。
调整维度顺序为(C, H, W)(通道×高×宽)。
典型用途:
from PIL import Image
image = Image.open("image.jpg") # 加载图像
tensor = transform(image) # 应用变换:调整大小 → 转为张量
此时tensor是一个形状为(C, H, W)、值在[0, 1]范围内的浮点型张量,可直接输入神经网络。
transform = transforms.Compose([transforms.Resize(400), transforms.ToTensor()])
root_dir = "练手数据集/train"
image_ants = "ants_image"
label_ants = "ants_label"
ants_dataset = MyData(root_dir, image_ants, label_ants, transform=transform)
image_bees = "bees_image"
label_bees = "bees_label"
bees_dataset = MyData(root_dir, image_bees, label_bees, transform=transform)
- 完整代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms
class MyData(Dataset):
def __init__(self, root_dir, image_dir, label_dir, transform=None):
self.root_dir = root_dir
self.image_dir = image_dir
self.label_dir = label_dir
self.label_path = os.path.join(self.root_dir, self.label_dir)
self.image_path = os.path.join(self.root_dir, self.image_dir)
self.image_list = os.listdir(self.image_path)
self.label_list = os.listdir(self.label_path)
self.transform = transform
# 因为label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的
self.image_list.sort()
self.label_list.sort()
def __getitem__(self, idx):
img_name = self.image_list[idx]
label_name = self.label_list[idx]
img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
img = Image.open(img_item_path)
with open(label_item_path, 'r') as f:
label = f.readline()
if self.transform:
img = transform(img)
return img, label
def __len__(self):
assert len(self.image_list) == len(self.label_list)
return len(self.image_list)
transform = transforms.Compose([transforms.Resize(400), transforms.ToTensor()])
root_dir = "练手数据集/train"
image_ants = "ants_image"
label_ants = "ants_label"
ants_dataset = MyData(root_dir, image_ants, label_ants, transform=transform)
image_bees = "bees_image"
label_bees = "bees_label"
bees_dataset = MyData(root_dir, image_bees, label_bees, transform=transform)
通过索引该实例查看一下返回值:返回的是一个图像对象和一个标签:
>>> ants_dataset[1]
(tensor([[[0.6941, 0.7490, 0.8118, ..., 0.8353, 0.8118, 0.7922],
[0.6549, 0.7216, 0.7961, ..., 0.8471, 0.8392, 0.8392],
[0.5882, 0.6784, 0.7412, ..., 0.8549, 0.8667, 0.8784],
...,
[0.8824, 0.8706, 0.8549, ..., 0.8510, 0.8392, 0.8275],
[0.8549, 0.8627, 0.8549, ..., 0.8510, 0.8392, 0.8235],
[0.8314, 0.8392, 0.8275, ..., 0.8588, 0.8471, 0.8392]],
[[0.6784, 0.7373, 0.8196, ..., 0.8196, 0.7922, 0.7647],
[0.6275, 0.6941, 0.7804, ..., 0.8353, 0.8314, 0.8275],
[0.5843, 0.6510, 0.7137, ..., 0.8353, 0.8471, 0.8588],
...,
[0.8627, 0.8471, 0.8275, ..., 0.8196, 0.8118, 0.8039],
[0.8431, 0.8431, 0.8275, ..., 0.8235, 0.8157, 0.8078],
[0.8196, 0.8196, 0.8000, ..., 0.8392, 0.8314, 0.8235]],
[[0.5804, 0.6510, 0.7333, ..., 0.7098, 0.7020, 0.7412],
[0.5529, 0.5882, 0.6863, ..., 0.7647, 0.7647, 0.7804],
[0.4902, 0.5373, 0.6196, ..., 0.8000, 0.8078, 0.8196],
...,
[0.8039, 0.8039, 0.7765, ..., 0.7882, 0.7725, 0.7647],
[0.7843, 0.7961, 0.7765, ..., 0.7961, 0.7804, 0.7647],
[0.7843, 0.7922, 0.7647, ..., 0.8078, 0.7961, 0.7804]]]), 'ants')
用两个变量分别接收图片对象和其标签:
>>> image, label = ants_dataset[1]
>>> image.shape
torch.Size([3, 400, 600])