深度学习图像预处理可视化:拆解Compose操作的全过程
深度学习图像预处理可视化:拆解Compose操作的全过程
背景需求
在深度学习图像处理中,我们经常使用torchvision.transforms.Compose
或timm
的create_transform
将多个预处理步骤组合成一个流水线。但在实际调试中,开发者常会遇到以下问题:
- 无法直观看到每个变换步骤对图像的具体影响
- 归一化(Normalize)后的张量难以直接可视化
- 随机增强(如翻转、裁剪)导致结果不可复现时难以定位问题
本文目标
通过代码实现以下功能:
- 逐步拆解预处理流水线,记录每个中间步骤的输出
- 自动可视化所有变换结果(包括PIL图像和归一化后的张量)
- 智能布局子图排列,避免空白区域过多
- 动态反归一化处理,还原可读性图像
技术实现亮点
# 关键代码段解析
cols = math.ceil(math.sqrt(total_steps)) # 根据步骤数量动态计算列数
rows = math.ceil(total_steps / cols) # 计算所需行数
# 反归一化处理(以Normalize步骤为例)
if name == "Normalize":
tmp = img_step.permute(1,2,0).numpy() * std + mean # 还原原始像素范围
tmp = np.clip(tmp, 0, 1) # 防止溢出
可视化流程说明
-
输入图像处理
- 原始图像通过
create_transform
定义的多阶段变换 - 包含典型操作:随机翻转(hflip=0.3)、中心裁剪(crop_pct=0.8)、归一化等
- 原始图像通过
-
中间结果捕获
intermediate_img = [img] # 初始化包含原始图像 for i in range(len(transform_list)): intermediate_img.append(transform_list[i](intermediate_img[i]))
- 通过循环逐步应用每个变换并保存结果
-
结果可视化
处理步骤 关键技术 PIL图像显示 直接渲染 Image.Image
对象张量显示 使用 .permute(1,2,0)
调整维度顺序(C×H×W → H×W×C)布局优化 动态计算行列数,保证接近正方形排列(如5个子图显示为2×3网格)
实际应用场景
- 代码演示:直观展示每个预处理步骤的效果
- 算法调试:定位导致图像异常的变换步骤
- 数据增强验证:检查随机裁剪/翻转是否合理
- 模型部署:验证预处理与训练时的一致性
import os
import math
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from timm.data.transforms_factory import create_transform
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# Step1: 读取图片
img = Image.open('DJI_20241009083951.jpg')
# Step2: 创建图像变换组合
transform = create_transform(
input_size=(224, 224),
is_training=True,
hflip=0.3,
vflip = 0.1,
crop_mode='border',
crop_pct=0.8,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)
# Step3: 显示变换组合
print(transform)
transform_list = transform.transforms # 将所有变换放到列表中
# Step4: 生成中间结果
intermediate_img = [img]
for i in range(len(transform_list)):
intermediate_img.append(transform_list[i](intermediate_img[i]))
# Step5: 动态计算子图布局
total_steps = len(intermediate_img)
cols = math.ceil(math.sqrt(total_steps)) # 列数
rows = math.ceil(total_steps / cols) # 行数动态计算
# 创建子图(保持二维结构)
fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
# Step6: 可视化所有中间结果
for i in range(total_steps):
# 获取当前步骤的名称和图像
name = "Original" if i == 0 else transform_list[i - 1].__class__.__name__
img_step = intermediate_img[i]
# 转换为可显示格式
if isinstance(img_step, Image.Image):
tmp = img_step
elif isinstance(img_step, torch.Tensor):
if name == "Normalize":
# 反归一化处理
mean = np.array(transform_list[i - 1].mean)
std = np.array(transform_list[i - 1].std)
tmp = img_step.permute(1, 2, 0).numpy() * std + mean
tmp = np.clip(tmp, 0, 1)
else:
tmp = img_step.permute(1, 2, 0).numpy()
# 计算子图位置
row = i // cols
col = i % cols
# 绘制子图
axs[row, col].imshow(tmp)
axs[row, col].axis('off')
axs[row, col].set_title(f"Step {i}: {name}", fontsize=8)
# 隐藏多余的空子图
for i in range(total_steps, rows * cols):
row = i // cols
col = i % cols
axs[row, col].axis('off')
plt.tight_layout()
plt.show()