当前位置: 首页 > news >正文

【IQA技术专题】DISTS代码讲解

本文是对DISTS图像质量评价指标的代码解读,原文解读请看DISTS文章讲解。
本文的代码来源于IQA-Pytorch工程。

1、原文概要

以前的一些IQA方法对于捕捉纹理上的感知一致性有所欠缺,鲁棒性不足。基于此,作者开发了一个能够在图像结构和图像纹理上都具有与人类相同感知判断的指标,在此之上,还希望纹理能够resample(不需要像素级对齐)之后也是一样的,另外区分开退化(JPEG,JPEG会损失纹理)。实现该指标可以分为4个步骤:

  1. 对图像进行一个初始的变换,从像素空间变换到特征空间。
  2. 对特征提取所谓纹理的表示,对特征提取所谓结构的表示。
  3. 利用纹理和结构的表示,加入一些可学习的权重综合计算一个评价指标。
  4. 利用这个评价指标,进一步优化权重得到纹理区域resample不敏感的指标,且能够有结构和纹理上做感知相似度的模型。

实现后的指标作为优化指标对比其他IQA指标有明显优势,如下图所示。
在这里插入图片描述

2、代码结构

代码实现位于pyiqa/archs/dists_arch.py中
在这里插入图片描述

3 、核心代码模块

L2pooling

这个类实现了我们前面提到的预处理部分替换max-pool的操作。

class L2pooling(nn.Module):def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):super(L2pooling, self).__init__()self.padding = (filter_size - 2) // 2self.stride = strideself.channels = channelsa = np.hanning(filter_size)[1:-1]g = torch.Tensor(a[:, None] * a[None, :])g = g / torch.sum(g)self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))def forward(self, input):input = input**2out = F.conv2d(input,self.filter,stride=self.stride,padding=self.padding,groups=input.shape[1],)return (out + 1e-12).sqrt()

这里可以看到前向的过程中作者先是进行了一个平方,然后使用了一个self.filter的滤波器,kernel_size为3的hanning窗,stride=2,且是一个深度可分离的卷积,groups与输入通道一致,这代替max-pool完成了一次抗混叠的下采样,最后进行一个sqrt,这与讲解中展示的公式一致,如下所示:
P(x)=g∗(x∗x)P(x)=\sqrt{g*(x*x)}P(x)=g(xx)这个ggg在初始化时被复制了self.channels次,实际它一个通道的数值,读者可以打印如下所示:
[0.06250.1250.06250.1250.250.1250.06250.1250.0625]\begin{bmatrix} 0.0625 & 0.125 & 0.0625 \\ 0.125 & 0.25 & 0.125 \\ 0.0625 & 0.125 & 0.0625 \end{bmatrix} 0.06250.1250.06250.1250.250.1250.06250.1250.0625一个典型的低通滤波器,做了一个空间上根据距离的平均。

DISTS

存放着跟实际计算指标相关的代码。

@ARCH_REGISTRY.register()
class DISTS(torch.nn.Module):r"""DISTS model.Args:pretrained_model_path (String): Pretrained model path."""def __init__(self, pretrained=True, pretrained_model_path=None, **kwargs):"""Refer to official code https://github.com/dingkeyan93/DISTS"""super(DISTS, self).__init__()vgg_pretrained_features = models.vgg16(weights='IMAGENET1K_V1').featuresself.stage1 = torch.nn.Sequential()self.stage2 = torch.nn.Sequential()self.stage3 = torch.nn.Sequential()self.stage4 = torch.nn.Sequential()self.stage5 = torch.nn.Sequential()for x in range(0, 4):self.stage1.add_module(str(x), vgg_pretrained_features[x])self.stage2.add_module(str(4), L2pooling(channels=64))for x in range(5, 9):self.stage2.add_module(str(x), vgg_pretrained_features[x])self.stage3.add_module(str(9), L2pooling(channels=128))for x in range(10, 16):self.stage3.add_module(str(x), vgg_pretrained_features[x])self.stage4.add_module(str(16), L2pooling(channels=256))for x in range(17, 23):self.stage4.add_module(str(x), vgg_pretrained_features[x])self.stage5.add_module(str(23), L2pooling(channels=512))for x in range(24, 30):self.stage5.add_module(str(x), vgg_pretrained_features[x])for param in self.parameters():param.requires_grad = Falseself.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))self.chns = [3, 64, 128, 256, 512, 512]self.register_parameter('alpha', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))self.register_parameter('beta', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))self.alpha.data.normal_(0.1, 0.01)self.beta.data.normal_(0.1, 0.01)if pretrained_model_path is not None:load_pretrained_network(self, pretrained_model_path, False)elif pretrained:load_pretrained_network(self, default_model_urls['url'], False)def forward_once(self, x):h = (x - self.mean) / self.stdh = self.stage1(h)h_relu1_2 = hh = self.stage2(h)h_relu2_2 = hh = self.stage3(h)h_relu3_3 = hh = self.stage4(h)h_relu4_3 = hh = self.stage5(h)h_relu5_3 = hreturn [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]def forward(self, x, y):r"""Compute IQA using DISTS model.Args:- x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images.- y: An reference tensor with (N, C, H, W) shape. RGB channel order for colour images.Returns:Value of DISTS model."""feats0 = self.forward_once(x)feats1 = self.forward_once(y)dist1 = 0dist2 = 0c1 = 1e-6c2 = 1e-6w_sum = self.alpha.sum() + self.beta.sum()alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)beta = torch.split(self.beta / w_sum, self.chns, dim=1)for k in range(len(self.chns)):x_mean = feats0[k].mean([2, 3], keepdim=True)y_mean = feats1[k].mean([2, 3], keepdim=True)S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_meanS2 = (2 * xy_cov + c2) / (x_var + y_var + c2)dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)score = 1 - (dist1 + dist2)return score.squeeze(-1).squeeze(-1)

3个重点如下:

  1. 初始化中首先会插入前面讲到的L2_Pooling,来替换原始的max-pool,其他的就是初始化必要的标准化变量和用于各层结构和纹理的加权系数α\alphaαβ\betaβ,最后导入预训练的网络即可。
  2. 前向中调用的forward_once,可以看到总共有6个输出,第一个输出是输入x,即我们讲解中提到的identity的变换,其他5层是事先定义好的输出位置。
  3. dists的计算:首先根据权重的大小对alpha和beta进行归一化,随后分层计算我们前面定义好的纹理特征和结构特征的相关性公式,针对于纹理的部分代码中是S1,可以看到S1是利用了特征的在空间上的均值计算的参考图像和待评估图像的相关系数,然后利用alpha对计算好的S1进行加权,得到纹理上相似度dist1;针对于结构的部分代码中是S2,S2是利用了参考图像和待评估图像两个特征的协方差和方差,由于是全局的窗口所以在计算后会求取空间上的一个均值,这样得到了结构上的相似度dist2。最后结合dist1和dist2得到最终的score。dists计算的公式如下,可以对照着公式来查看:
    l(x~j(i),y~j(i))=2μx~j(i)μy~j(i)+c1(μx~j(i))2+(μy~j(i))2+c1l(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) = \frac{2\mu_{\tilde{x}_j}^{(i)}\mu_{\tilde{y}_j}^{(i)} + c_1}{(\mu_{\tilde{x}_j}^{(i)})^2 + (\mu_{\tilde{y}_j}^{(i)})^2 + c_1}l(x~j(i),y~j(i))=(μx~j(i))2+(μy~j(i))2+c12μx~j(i)μy~j(i)+c1 s(x~j(i),y~j(i))=2σx~jy~j(i)+c2(σx~j(i))2+(σy~j(i))2+c2,s(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) = \frac{2\sigma_{\tilde{x}_j\tilde{y}_j}^{(i)} + c_2}{(\sigma_{\tilde{x}_j}^{(i)})^2 + (\sigma_{\tilde{y}_j}^{(i)})^2 + c_2},s(x~j(i),y~j(i))=(σx~j(i))2+(σy~j(i))2+c22σx~jy~j(i)+c2, D(x,y;α,β)=1−∑i=0m∑j=1ni(αijl(x~j(i),y~j(i))+βijs(x~j(i),y~j(i)))D(x, y; \alpha, \beta) = 1 - \sum_{i = 0}^{m} \sum_{j = 1}^{n_i} \left( \alpha_{ij} l(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) + \beta_{ij} s(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) \right)D(x,y;α,β)=1i=0mj=1ni(αijl(x~j(i),y~j(i))+βijs(x~j(i),y~j(i)))其中,lllsss分别代表纹理和结构。

3、总结

代码实现核心的部分讲解完毕,DISTS作为一个可以同时捕获结构和纹理相似度的全参考IQA指标,在很多比赛和论文的引用中都可以见到它的身影,实用性是毋庸置疑的。
大家有涉及到数据集筛选、纹理分类、纹理搜索类的任务可以尝试使用DISTS指标,或者是在算法评估中利用它来做一个方面的对比评估。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

http://www.dtcms.com/a/308154.html

相关文章:

  • 深入剖析:C++ 手写实现 unordered_map 与 unordered_set 全流程指南
  • Qt 如何从 .ts 文件提取所有源文
  • 2024年SEVC SCI2区,一致性虚拟领航者跟踪群集算法GDRRT*-PSO+多无人机路径规划,深度解析+性能实测
  • TDengine 中 TDgp 中添加算法模型(异常检测)
  • 【生活篇】Ubuntu22.04安装网易云客户端
  • 河南萌新联赛2025第(三)场:河南理工大学(补题)
  • .NET 10 中的新增功能系列文章3—— .NET MAUI 中的新增功能
  • gen_compile_commands.sh
  • elk部署加日志收集
  • 网络爬虫(python)入门
  • webpack-babel
  • 开发避坑短篇(11):Oracle DATE(7)到MySQL时间类型精度冲突解决方案
  • uniapp x swiper/image组件mode=“aspectFit“ 图片有的闪现后黑屏
  • Vue多请求并行处理实战指南
  • 【qiankun】基于vite的qiankun微前端框架下,子应用的静态资源无法加载的问题
  • [硬件电路-111]:滤波的分类:模拟滤波与数字滤波; 无源滤波与有源滤波;低通、带通、带阻、高通滤波;时域滤波与频域滤波;低价滤波与高阶滤波。
  • 2025做美业还有前景吗?博弈美业系统带来美业市场分析
  • rustdesk 1.4.1版本全解析:新增功能、性能优化与多平台支持详解
  • 【机器学习】KNN算法与模型评估调优
  • 深度学习批量矩阵乘法实战解析:torch.bmm
  • 【科普】在STM32中有哪些定时器?
  • 【Golang】用官方rate包构造简单IP限流器
  • 【STM32】HAL库中的实现(二):串口(USART)/看门狗(IWDG/WWDG)/定时器(TIM)
  • 三格——环网式CAN光纤中继器进行光纤冗余环网组网测试
  • 工业绝缘监测仪:保障工业电气安全的关键防线
  • C# 枚举器和迭代器(常见迭代器模式)
  • 26考研|数学分析:重积分
  • ubuntu24.04环境下树莓派Pico C/C++ SDK开发环境折腾记录
  • 设计模式:命令模式 Command
  • AI驱动下的数据新基建:腾讯游戏数据资产治理与湖仓架构革新