当前位置: 首页 > news >正文

PyTorch图像预处理全解析(transforms)

1. 引言

在深度学习计算机视觉任务中,数据预处理和数据增强是模型训练的关键步骤,直接影响模型的泛化能力和最终性能表现。PyTorch 提供的 torchvision.transforms 模块,封装了丰富的图像变换方法,能够高效地完成图像标准化、裁剪、翻转等操作。该模块支持两种主要的使用方式:单步变换(Single Transform)和组合变换(Compose),可以灵活应对不同场景下的图像处理需求。

本文将详细解析 transforms 的核心 API、参数含义,并通过完整代码示例演示其使用方法。主要内容包括:

  1. 基础变换操作

    • 尺寸调整:Resize(target_size)
    • 随机裁剪:RandomCrop(size, padding=None, pad_if_needed=False)
    • 中心裁剪:CenterCrop(size)
    • 随机水平/垂直翻转:RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5)
  2. 颜色空间变换

    • 颜色抖动:ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
    • 随机灰度化:RandomGrayscale(p=0.1)
    • 高斯模糊:GaussianBlur(kernel_size, sigma=(0.1, 2.0))
  3. 数据标准化

    • 归一化:Normalize(mean, std)
    • 张量转换:ToTensor()
  4. 实用组合方法

    • 变换链:Compose([transforms1, transforms2,...])
    • 随机选择:RandomApply(transforms, p=0.5)
    • 随机排序:RandomOrder(transforms)

以图像分类任务为例,一个典型的数据增强流程可能如下:

from torchvision import transformstrain_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

其中,训练集使用更丰富的增强策略以提高模型鲁棒性,而验证集则采用较简单的预处理保持数据原始分布。通过合理配置这些变换参数,可以显著提升模型在各种视觉任务(如图像分类、目标检测、语义分割等)中的表现。


2. transforms 概述

transforms 是 PyTorch 生态系统中 torchvision 库的核心模块之一,专门用于计算机视觉任务中的图像数据处理。它提供了丰富的图像变换方法,主要分为三大类功能:

  1. 图像预处理

    • 尺寸调整:transforms.Resize() 可将图像统一缩放到指定尺寸(如 256x256)
    • 归一化:transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 使用 ImageNet 的均值和标准差进行标准化
    • 中心裁剪:transforms.CenterCrop(224) 从图像中心裁剪出指定大小的区域
  2. 数据增强(常用于训练阶段防止过拟合):

    • 随机裁剪:transforms.RandomCrop(224) 在随机位置裁剪
    • 颜色变换:transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    • 随机水平翻转:transforms.RandomHorizontalFlip(p=0.5)
    • 随机旋转:transforms.RandomRotation(degrees=15)
  3. 格式转换

    • PIL图像转张量:transforms.ToTensor() 将图像转换为 PyTorch 张量(并自动将像素值归一化到 [0,1])
    • 张量转PIL图像:transforms.ToPILImage()

组合使用示例

from torchvision import transforms# 训练阶段的变换流水线
train_transform = transforms.Compose([transforms.Resize(256),              # 缩放至256x256transforms.RandomCrop(224),          # 随机裁剪224x224transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.ToTensor(),               # 转为张量transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化std=[0.229, 0.224, 0.225])
])# 验证阶段的变换流水线(通常不包含随机增强)
val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

在实际应用中,这些变换可以显著提升模型的泛化能力,特别是在数据量有限的情况下。对于不同的计算机视觉任务(如图像分类、目标检测等),可以根据具体需求组合不同的变换操作。


3. 核心 API 详解

3.1 基础变换

(1) Resize(size)
  • 功能:调整图像尺寸。

  • 参数

    • size (int or tuple):目标尺寸。如果是 int,短边缩放至该值,长边按比例调整;如果是 (h, w),则强制缩放到指定大小。

  • 示例

transform = transforms.Resize(256)  # 短边缩放到256,长边按比例调整
transform = transforms.Resize((224, 224))  # 强制缩放到224x224
(2) CenterCrop(size)
  • 功能:从图像中心裁剪指定大小的区域。

  • 参数

    • size (int or tuple):裁剪尺寸(int 表示正方形,(h, w) 表示矩形)。

  • 示例

transform = transforms.CenterCrop(224)  # 裁剪224x224的正方形
(3) RandomCrop(size)
  • 功能:随机位置裁剪图像。

  • 参数

    • size (int or tuple):裁剪尺寸。

    • padding (int or tuple, optional):填充边缘(防止裁剪过小)。

  • 示例

transform = transforms.RandomCrop(224, padding=10)  # 随机裁剪224x224,边缘填充10像素
(4) RandomHorizontalFlip(p=0.5)
  • 功能:以概率 p 水平翻转图像(默认 p=0.5)。

  • 示例

transform = transforms.RandomHorizontalFlip(p=0.7)  # 70%概率水平翻转
(5) RandomRotation(degrees)
  • 功能:随机旋转图像。

  • 参数

    • degrees (float or tuple):旋转角度范围(如 30 表示 [-30°, 30°](10, 30) 表示 [10°, 30°])。

  • 示例

transform = transforms.RandomRotation(30)  # 随机旋转 ±30°

3.2 张量转换 & 标准化

(1) ToTensor()
  • 功能

    • 将 PIL.Image 或 numpy.ndarray 转换为 torch.Tensor[C, H, W] 格式)。

    • 像素值从 [0, 255] 缩放到 [0.0, 1.0]

  • 示例

transform = transforms.ToTensor()  # 转换为张量
(2) Normalize(mean, std)
  • 功能:对张量进行标准化(逐通道计算:(x - mean) / std)。

  • 参数

    • mean (list):各通道均值(如 ImageNet 的 [0.485, 0.456, 0.406])。

    • std (list):各通道标准差(如 ImageNet 的 [0.229, 0.224, 0.225])。

  • 示例

transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)

3.3 颜色变换

(1) ColorJitter
  • 功能:随机调整亮度、对比度、饱和度和色相。

  • 参数说明:

  • brightness (float 或 tuple):亮度调整范围

    • 当输入为单个浮点数时(如 0.2),表示亮度调整范围为 [1-0.2, 1+0.2] = [0.8, 1.2]
    • 当输入为元组时(如 (0.7, 1.3)),表示直接指定亮度调整范围
    • 示例:brightness=0.5 表示图片亮度将在原始值的50%-150%之间随机调整
  • contrast (float 或 tuple):对比度调整范围

    • 调节方式与brightness相同
    • 示例:contrast=(0.8, 1.5) 表示对比度将在原始值的80%-150%之间随机调整
    • 应用场景:

    • 这些参数常用于图像增强和数据增强任务
    • 在训练深度学习模型时,随机调整这些参数可以增加训练数据的多样性
    • 每个参数的调整都是在指定范围内随机取值,而不是固定值
    • saturation (float 或 tuple):饱和度调整范围

      • 调节方式与brightness相同
      • 示例:saturation=0.3 表示饱和度将在原始值的70%-130%之间随机调整
    • hue (float 或 tuple):色相调整范围

      • 当输入为单个浮点数时(如 0.1),表示色相调整范围为 [-0.1, 0.1]
      • 当输入为元组时(如 (-0.2, 0.3)),表示直接指定色相调整范围
      • 注意:色相值通常以弧度表示,范围一般为[-0.5, 0.5]
      • 示例:hue=0.05 表示色相将在[-0.05, 0.05]范围内随机调整
  • 示例

transform = transforms.ColorJitter(brightness=0.2,contrast=0.2,saturation=0.2,hue=0.1
)
(2) Grayscale(num_output_channels=1)
  • 功能:将图像转为灰度图。

  • 参数

    • num_output_channels:输出通道数(1 或 3)。

  • 示例

transform = transforms.Grayscale(num_output_channels=3)  # 转为3通道灰度图

4. 完整代码示例

4.1 定义训练和测试的变换

from torchvision import transforms# 训练集变换(含数据增强)
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),      # 随机缩放裁剪至224x224transforms.RandomHorizontalFlip(),      # 50%概率水平翻转transforms.ColorJitter(                 # 随机颜色调整brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),                 # 转为张量 [C, H, W], 值范围[0, 1]transforms.Normalize(                  # 标准化(ImageNet参数)mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 测试集变换(仅预处理)
test_transform = transforms.Compose([transforms.Resize(256),                # 短边缩放到256transforms.CenterCrop(224),            # 中心裁剪224x224transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

4.2 应用到数据集 

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 加载CIFAR10数据集(应用变换)
train_dataset = CIFAR10(root='./data', train=True, transform=train_transform,  # 应用训练变换download=True
)test_dataset = CIFAR10(root='./data', train=False, transform=test_transform,   # 应用测试变换download=True
)# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

5. 总结

使用 Compose 可以方便地组合多个变换操作,这些变换会按照添加顺序依次执行。例如:

transforms.Compose([transforms.Resize(256),          # 调整图像大小transforms.RandomCrop(224),      # 随机裁剪transforms.ToTensor(),           # 转换为张量transforms.Normalize(            # 标准化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

在实际应用中,训练和测试阶段通常采用不同的转换策略:

标准化(Normalize)是一个关键步骤,它能:

当使用预训练模型时,应该采用该模型训练时使用的均值和标准差(常见的是 ImageNet 的统计值:mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])。如果不使用预训练模型,可以计算自己数据集的统计值进行标准化。

  • PyTorch 中的 transforms 模块是计算机视觉任务中图像处理的核心工具,它提供了一系列用于图像预处理、数据增强和数据类型转换的功能。这些转换操作可以高效地将原始图像数据转换为适合深度学习模型训练的格式。

    主要功能包括:

  • 预处理:如图像大小调整(Resize)、中心裁剪(CenterCrop)、转换为张量(ToTensor)等基础操作
  • 数据增强:训练时增加数据多样性的随机变换,如随机水平翻转(RandomHorizontalFlip)、随机旋转(RandomRotation)
  • 张量转换:将 PIL 图像或 numpy 数组转换为 PyTorch 张量,并进行数值归一化等操作
  • 训练阶段:建议使用数据增强来提升模型泛化能力,常用增强方法包括:
    • RandomCrop:随机裁剪图像
    • ColorJitter:随机调整亮度、对比度、饱和度
    • RandomHorizontalFlip:随机水平翻转
    • RandomRotation:随机旋转
  • 测试阶段:通常只需基础预处理,如固定大小的裁剪和标准化
  • 将输入数据缩放到相近的数值范围
  • 加速模型收敛过程
  • 提高训练稳定性

掌握 transforms 的使用,可以显著提升计算机视觉任务的效率和模型性能! 

http://www.dtcms.com/a/285073.html

相关文章:

  • SAP-ABAP:SAP的‘cl_http_utility=>escape_url‘对URL进行安全编码方法详解
  • 6 基于STM32单片机的智能家居系统设计(STM32代码编写+手机APP设计+PCB设计+Proteus仿真)
  • 如何从 iPhone 向Mac使用 AirDrop 传输文件
  • 企业网络运维进入 “AI 托管” 时代:智能分析 + 自动决策,让云、网、端一眼看穿
  • 关于用git上传远程库的一些常见命令使用和常见问题:
  • Redis学习-02安装Redis(Ubuntu版本)、开启远程连接
  • ComfyUI 中RAM内存、VRAM显存、GPU 的占用和原理
  • 基于深度学习的图像识别:从零构建卷积神经网络(CNN)
  • 面对微软AD的安全隐患,宁盾身份域管如何设计安全性
  • Python调用父类方法的三种方式详解 | Python面向对象编程教程
  • 【DOCKER】-5 镜像仓库与容器编排
  • 云服务器如何设置防火墙和安全组规则?
  • Java EE进阶3:SpringBoot 快速上手
  • 【Linux】Makefile(二)-书写规则
  • 【原创】【图像算法】高精密电子仪器组装异常检测
  • 力扣119:杨辉三角Ⅱ
  • Cursor出现This model provider doesn’t serve your region解决方案
  • 【调度算法】
  • javaScript中数组常用的函数方法
  • 洛谷 P1601 A+B Problem(高精)
  • 重构比特币在 Sui DeFi 中的角色
  • Redis中什么是看门狗机制
  • 解决leetcode第3614题用特殊操作处理字符串II
  • 魔术公式轮胎simulink模型建立及参数拟合
  • 实现atm提款简易代码
  • ​​孤儿进程:当父进程先离开时会发生什么?
  • LeetCode|Day17|242. 有效的字母异位词|Python刷题笔记
  • 云服务器的数据如何备份和恢复?
  • Leetcode刷题营第二十八题:二叉树的前序遍历
  • CSS关键字:initial、revert、unset傻傻分不清