PyTorch仿射变换:原理与实战全解析
仿射变换在 PyTorch 中的实现与原理
核心概念
仿射变换是一种保持直线和比例关系的几何变换,由以下操作组成:
- 旋转(Rotation)
- 缩放(Scale)
- 平移(Translation)
- 剪切(Shear)
- 反射(Reflection)
数学形式可表示为齐次坐标下的矩阵运算:
[x'] [a b tx] [x]
[y'] = [c d ty] [y]
[1 ] [0 0 1 ] [1]
其中 a, b, c, d
控制旋转/缩放/剪切,tx, ty
控制平移。
PyTorch 内置实现方式
1. torchvision.transforms.RandomAffine
from torchvision import transformstransform = transforms.Compose([transforms.RandomAffine(degrees=15, # 旋转范围:±15度translate=(0.1, 0.1), # 平移比例:最大10%scale=(0.8, 1.2), # 缩放范围:80%~120%shear=10, # 剪切范围:±10度fill=0 # 空白区域填充值)
])
2. torch.nn.functional.affine_grid
+ grid_sample
底层控制实现:
import torch.nn.functional as Fdef apply_affine(x, matrix):# x: [C, H, W] tensormatrix = matrix[:2, :] # 取前两行grid = F.affine_grid(matrix.unsqueeze(0), x.unsqueeze(0).size())return F.grid_sample(x.unsqueeze(0), grid)[0]# 示例矩阵:旋转30度 + 缩放0.8倍
theta = torch.tensor([[0.8 * math.cos(math.pi/6), -0.8 * math.sin(math.pi/6), 0],[0.8 * math.sin(math.pi/6), 0.8 * math.cos(math.pi/6), 0]
])
transformed = apply_affine(image, theta)
参数详解
参数 | 类型 | 说明 | 计算公式 |
---|---|---|---|
degrees | float/tuple | 旋转角度范围 | 随机角度 ∈ [-deg, +deg] |
translate | tuple | 平移比例 (h, w) | 偏移量 = 图片尺寸 × 随机值 × trans |
scale | tuple | 缩放比例范围 | 随机缩放 ∈ [scale_min, scale_max] |
shear | float/tuple | 剪切角度 | 支持XY方向单独设置 |
fill | int/tuple | 空白区域填充 | 可设统一值或RGB分量值 |
center | tuple | 旋转中心点 | 默认中心点:image_size/2 |
核心实现原理
1. 矩阵合成(顺序敏感)
def get_affine_matrix():# 基础单位矩阵matrix = [[1, 0, 0],[0, 1, 0],[0, 0, 1]]# 应用变换(顺序:缩放→旋转→平移)matrix = compose_matrix(scale_matrix(sx, sy))matrix = compose_matrix(rotation_matrix(theta), matrix)matrix = compose_matrix(translation_matrix(tx, ty), matrix)return matrix[:2] # 返回2x3矩阵
2. 网格采样与双线性插值
def grid_sample(input, grid):for y in range(H):for x in range(W):src_x, src_y = grid[y, x] # 计算原图坐标# 双线性插值(实际使用优化算法)top_left = input[floor(src_y), floor(src_x)]top_right = input[floor(src_y), ceil(src_x)]...output[y,x] = bilinear_interp(top_left, top_right, ...)
3. 边界处理
- 填充处理:
fill
参数设置空白像素值 - 采样模式:
padding_mode='zeros'
(默认)padding_mode='border'
(复制边缘像素)padding_mode='reflection'
(镜像反射)
可视化示例
import matplotlib.pyplot as plt# 原始图像
plt.subplot(231)
plt.imshow(orig_img)
plt.title("Original")# 旋转示例
plt.subplot(232)
rotated = transforms.functional.rotate(tensor_img, 30)
plt.imshow(rotated.permute(1,2,0))
plt.title("Rotation")# 缩放示例
plt.subplot(233)
scaled = transforms.functional.resize(tensor_img, (100,150))
plt.imshow(scaled.permute(1,2,0))
plt.title("Scale")# 剪切示例
sheared = F.affine(tensor_img, angle=0, translate=(0,0),scale=1,shear=(20,0))
plt.subplot(234)
plt.imshow(sheared.permute(1,2,0))
plt.title("Shear")plt.show()
数学基础
1. 基本变换矩阵
变换类型 | 矩阵格式 |
---|---|
平移 | [[1, 0, tx], [0, 1, ty]] |
缩放 | [[sx, 0, 0], [0, sy, 0]] |
旋转 | [[cosθ, -sinθ, 0], [sinθ, cosθ, 0]] |
剪切 | [[1, sh_x, 0], [sh_y, 1, 0]] |
2. 组合变换
M = T \times R \times S \times Sh
其中:
- T:平移矩阵
- R:旋转矩阵
- S:缩放矩阵
- Sh:剪切矩阵
应用场景
-
数据增强:提升模型对几何变换的鲁棒性
train_transform = transforms.Compose([transforms.RandomAffine(degrees=20, shear=10),transforms.ToTensor() ])
-
图像校准:医学影像/卫星图像配准
# 计算最优变换矩阵 optimizer = torch.optim.Adam([matrix_params], lr=0.01) for _ in range(100):warped = F.affine_grid(matrix_params, target.shape)loss = F.mse_loss(warped, target)loss.backward()optimizer.step()
-
空间变换网络 (STN):可学习仿射层
stn = nn.Sequential(nn.Conv2d(1, 10, 5),nn.MaxPool2d(2),nn.ReLU(),nn.Linear(10*12*12, 6) # 输出6个仿射参数 )
性能优化技巧
-
预计算网格:
# 对batch重复使用相同变换时 grid = F.affine_grid(matrix.expand(batch, -1), size)
-
设置梯度计算:
with torch.set_grad_enabled(mode == 'train'):transformed = F.grid_sample(input, grid)
-
插值方式选择:
F.grid_sample(..., mode='bilinear') # 训练常用 F.grid_sample(..., mode='nearest') # 高精度应用
实测速度比较(RTX 3090,256x256图像):
- 单次变换耗时:0.5~1.2 ms
- 双三次插值比双线性慢3倍
- 梯度计算增加15%耗时
与相关技术对比
特性 | 仿射变换 | 透视变换 | 弹性变形 |
---|---|---|---|
自由度 | 6 | 8 | 高 |
保持平行 | ✓ | ✗ | ✗ |
实现难度 | 简单 | 中等 | 复杂 |
PyTorch支持 | 内置 | 需自定义 | GridDistortion |
典型应用 | 几何增强 | 3D投影 | 医学影像 |
仿射变换因其效率与实用性的平衡,成为计算机视觉中应用最广泛的几何变换方法。