当前位置: 首页 > news >正文

图像增广——弹性形变

一、非Pytorch版(计算量较小)

import torch
import numpy as np
from scipy.ndimage import map_coordinates, gaussian_filter
from scipy import interpolate
import matplotlib.pyplot as pltclass ElasticTransform:def __init__(self, alpha=50, sigma=5, p=0.5):"""Args:alpha: 形变强度因子,控制形变幅度sigma: 高斯滤波的标准差,控制形变平滑度p: 应用变换的概率"""self.alpha = alphaself.sigma = sigmaself.p = pdef __call__(self, img):if torch.rand(1).item() > self.p:return img# 转换为numpy数组进行处理if torch.is_tensor(img):img_np = img.numpy()is_tensor = Trueelse:img_np = np.array(img)is_tensor = False# 获取图像尺寸if img_np.ndim == 3:  # (C, H, W)c, h, w = img_np.shapeelse:  # (H, W)h, w = img_np.shapec = 1img_np = img_np[np.newaxis, :, :]# 生成随机位移场dx = np.random.uniform(-1, 1, (h, w)) * self.alphady = np.random.uniform(-1, 1, (h, w)) * self.alpha# 对位移场进行高斯平滑dx = gaussian_filter(dx, sigma=self.sigma, mode='constant')dy = gaussian_filter(dy, sigma=self.sigma, mode='constant')# 创建坐标网格x, y = np.meshgrid(np.arange(w), np.arange(h))# 应用位移indices_x = np.reshape(x + dx, (-1, 1))indices_y = np.reshape(y + dy, (-1, 1))# 对每个通道应用弹性变换transformed_channels = []for channel in range(c):transformed_channel = map_coordinates(img_np[channel], [indices_y.ravel(), indices_x.ravel()], order=3,  # 三次样条插值mode='reflect').reshape(h, w)transformed_channels.append(transformed_channel)transformed_img = np.stack(transformed_channels)# 恢复原始形状if c == 1:transformed_img = transformed_img[0]# 转换回tensorif is_tensor:transformed_img = torch.from_numpy(transformed_img).float()return transformed_img

二、Pytorch版(计算量较大)

import torch
import torch.nn.functional as Fclass ElasticTransformTorch:def __init__(self, alpha=50, sigma=5, p=0.5):self.alpha = alphaself.sigma = sigmaself.p = pdef __call__(self, img):if torch.rand(1).item() > self.p:return img# 确保是tensorif not torch.is_tensor(img):img = torch.from_numpy(np.array(img)).float()# 获取设备信息device = img.device# 获取图像尺寸if img.dim() == 3:  # (C, H, W)c, h, w = img.shapeelse:  # (H, W)h, w = img.shapec = 1img = img.unsqueeze(0)# 生成随机位移场dx = torch.randn(h, w, device=device) * self.alphady = torch.randn(h, w, device=device) * self.alpha# 使用高斯滤波平滑位移场kernel_size = int(6 * self.sigma) + 1dx = self._gaussian_filter(dx, kernel_size, self.sigma)dy = self._gaussian_filter(dy, kernel_size, self.sigma)# 创建标准化网格 (-1 到 1)y_grid, x_grid = torch.meshgrid(torch.linspace(-1, 1, h, device=device),torch.linspace(-1, 1, w, device=device),indexing='ij')# 应用位移 (归一化到 [-1, 1] 范围)grid_x = x_grid + (2 * dx / w)grid_y = y_grid + (2 * dy / h)# 组合网格grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # (1, H, W, 2)# 对每个通道应用网格采样transformed_channels = []for channel in range(c):channel_img = img[channel:channel+1].unsqueeze(0)  # (1, 1, H, W)transformed_channel = F.grid_sample(channel_img, grid, mode='bilinear', padding_mode='reflection',align_corners=True)transformed_channels.append(transformed_channel.squeeze())transformed_img = torch.stack(transformed_channels)# 恢复原始形状if c == 1:transformed_img = transformed_img[0]return transformed_imgdef _gaussian_filter(self, x, kernel_size, sigma):"""应用高斯滤波"""# 创建高斯核coords = torch.arange(kernel_size, device=x.device) - kernel_size // 2g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))g = g / g.sum()# 应用分离的高斯滤波x = x.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)# 水平方向padding = kernel_size // 2x = F.conv2d(x, g.view(1, 1, 1, -1), padding=(0, padding))# 垂直方向x = F.conv2d(x, g.view(1, 1, -1, 1), padding=(padding, 0))return x.squeeze()

http://www.dtcms.com/a/512490.html

相关文章:

  • 视频推拉流平台EasyDSS技术特点解析及多元应用场景剖析
  • 做网站需要学php吗北京公司注册代理
  • 职高门户网站建设标准wordpress火车头发布模板
  • CycleGAN实现MNIST与SVHN风格迁移
  • AVL树手撕,超详细图文详解
  • ZeroTier虚拟局域网内搭建DNS服务器
  • 网络与通信安全课程复习汇总3——身份认证
  • 诸城网站做的好的创网站 灵感
  • C++多线程、STL
  • 自己做的网站怎么加入微信支付哪个网站做五金冲压的
  • MySQL数据库05:DQL查询运算符
  • 橙米网站建设网站建设合同制人员招聘
  • 织梦网站图片修改文化墙 北京广告公司
  • VTK——双重深度剥离
  • Linux小课堂: 软件安装与源码编译实战之从 RPM 到源码构建的完整流程
  • 【Python编程】之面向对象
  • Day67 Linux I²C 总线与设备驱动架构、开发流程与调试
  • 【AI增强质量管理体系结构】AI+自动化测试引擎 与Coze
  • 音频共享耳机专利拆解:碰击惯性数据监测与阈值减速识别机制研究
  • 青岛专业网站设计公司网站后台程序怎么做
  • MySQL创建用户、权限分配以及添加、修改权限
  • 【循环神经网络基础】
  • 郑州网站建设与设计校园网站建设年度总结
  • 中国新冠一共死去的人数网站优化和提升网站排名怎么做
  • 太仓手机网站建设阿里云如何做网站
  • 第二篇:按键交互入门:STM32 GPIO输入与消抖处理
  • JSP九大内置对象
  • 如何选择大良网站建设网站建设插件代码大全
  • 卡码网语言基础课(Python) | 17.判断集合成员
  • 温州专业网站建设成都营销推广公司