transforms学习笔记
视频来源:https://www.bilibili.com/video/BV1hE411t7RN?t=6.1&p=13
一、什么是 transforms?
transforms
是PyTorch提供的图像预处理工具集,主要用于:
- 数据格式转换(如 PIL 图像转 Tensor)
- 图像增强(如缩放、裁剪、旋转等)
- 标准化处理(如均值方差归一化)
- 数据增强(如随机裁剪、翻转等,用于防止过拟合)
1. ToTensor()
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
- 功能:将PIL图像或numpy数组转换为 PyTorch 的 Tensor
- 转换细节:
- 会将图像的像素值从 [0, 255] 范围归一化到 [0, 1]
- 会调整维度顺序,从 (H, W, C) 转换为 (C, H, W)
- 是绝大多数图像处理流程的第一步
2. Normalize()
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
- 功能:对 Tensor 进行标准化处理
- 计算公式:
output = (input - mean) / std
- 参数说明:
- 第一个参数是每个通道的均值
- 第二个参数是每个通道的标准差
- 常见用法:使用 ImageNet 数据集的均值和标准差
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
3. Resize()
trans_resize = transforms.Resize((224, 224)) # 固定尺寸
trans_resize_2 = transforms.Resize(512) # 按比例缩放,短边为512
- 功能:调整图像尺寸
- 参数说明:
- 可以传入元组
(height, width)
指定固定尺寸 - 可以传入单个整数,此时会按比例缩放,短边会被调整为该值
- 可以传入元组
- 注意:Resize 作用于 PIL 图像,而不是 Tensor
4. Compose()
trans_compose = transforms.Compose([trans_resize_2, transforms.ToTensor()])
- 功能:将多个变换组合成一个序列
- 执行顺序:按照列表中的顺序依次执行
- 注意事项:前一个变换的输出必须符合后一个变换的输入要求
5. RandomCrop()
trans_random = transforms.RandomCrop((400, 150))
- 功能:随机裁剪图像的指定尺寸区域
- 特点:每次调用会产生不同的裁剪结果
- 用途:数据增强,增加训练样本的多样性
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open(r"E:\pycharm\learn_pytorch\date\train\shot\c39e6ff6-d65f-4a9b-b9ef-bf87e52b2665.png")# 将RGBA图像转换为RGB
if img.mode == 'RGBA':img = img.convert('RGB')print(f"图像尺寸: {img.size}") # (宽, 高)
print(img)# 1.ToTensor使用
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)# 2.Normalize
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm, 1)# 3.Resize
print(f"调整前图像尺寸: {img.size}")
trans_resize = transforms.Resize((224, 224))
img_resize_pil = trans_resize(img)
img_resize_tensor = trans_totensor(img_resize_pil)
writer.add_image("Resize", img_resize_tensor, 0)
print(img_resize_tensor)# Compose - resize - 2
trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize_2, transforms.ToTensor()])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 1)# RandomCrop - 修复尺寸问题
# 使用适合原始图像大小的裁剪尺寸
# 原始图像尺寸是(181, 626),我们选择(150, 400)的裁剪尺寸
trans_random = transforms.RandomCrop((400, 150)) # (高, 宽),注意顺序与图像size的(宽,高)相反
trans_compose_2 = transforms.Compose([trans_random, transforms.ToTensor()])
for i in range(10):img_crop = trans_compose_2(img)writer.add_image("RandomCrop", img_crop, i)writer.close()
PS E:\pycharm\learn_pytorch> tensorboard --logdir=logs
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.20.0 at http://localhost:6006/ (Press CTRL+C to quit)
二、常用 transforms 拓展
1. 图像翻转
# 随机水平翻转,概率为0.5
trans_hflip = transforms.RandomHorizontalFlip(p=0.5)
# 随机垂直翻转,概率为0.5
trans_vflip = transforms.RandomVerticalFlip(p=0.5)
2. 旋转与角度变换
# 随机旋转(-30, 30)度
trans_rotate = transforms.RandomRotation(30)
# 随机选择给定角度中的一个进行旋转
trans_rotate_2 = transforms.RandomRotation(degrees=[90, 180, 270])
3. 亮度、对比度、饱和度调整
# 随机调整亮度
trans_brightness = transforms.ColorJitter(brightness=0.5)
# 综合调整亮度、对比度、饱和度和色调
trans_color = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2)
4. 中心裁剪
# 从中心裁剪指定尺寸
trans_center_crop = transforms.CenterCrop((224, 224))
5. 随机擦除
# 在图像上随机选择区域进行擦除(填充0或随机值)
trans_erase = transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))
6. 转换为灰度图
# 将图像转换为灰度图
trans_gray = transforms.Grayscale(num_output_channels=1)
三、transforms 使用最佳实践
1.训练集和验证集使用不同的 transforms:
- 训练集:使用更多的数据增强(随机裁剪、翻转等)
- 验证集:只使用必要的转换(Resize、ToTensor、Normalize)
2.典型的训练集转换管道:
train_transform = transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪并调整大小transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ColorJitter(brightness=0.2, contrast=0.2), # 随机调整亮度和对比度transforms.ToTensor(), # 转换为Tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])
3.典型的验证集转换管道:
val_transform = transforms.Compose([transforms.Resize(256), # 调整大小transforms.CenterCrop(224), # 中心裁剪transforms.ToTensor(), # 转换为Tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])
4.与 DataLoader 结合使用:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoadertrain_dataset = ImageFolder(root='path/to/train', transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataset = ImageFolder(root='path/to/val', transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
通过合理使用 transforms,可以有效提高模型的泛化能力,尤其是在训练数据有限的情况下,数据增强技术显得尤为重要。