当前位置: 首页 > news >正文

(即插即用模块-Attention部分) 六十三、(2024 CVPR) MLKA 多尺度大核注意力

在这里插入图片描述

文章目录

  • 1、Multi-scale Large Kernel Attention
  • 2、代码实现

paper:MULTI-SCALE ATTENTION NETWORK FOR SINGLE IMAGE SUPER-RESOLUTION

Code:https://github.com/icandle/MAN


1、Multi-scale Large Kernel Attention

为了解决如何有效地建立不同区域之间的长距离相关性,并避免由于大卷积核带来的“块效应”问题。这篇论文在 LKA 的基础上提出了一种 多尺度大核注意力(Multi-scale Large Kernel Attention),MLKA 的设计动机是为了解决图像超分辨率任务中,MLKA 结合了 大卷积核分解 和 多尺度机制 来实现这一目标。

MLKA 的实现过程:

  1. 输入特征图 X: 输入特征图 X 被分解成多个组,每个组包含相同数量的通道。
  2. LKA 模块: 对每个组应用 LKA 模块,生成不同尺度上的注意力图 LKAi。
  3. 门控模块: 为了避免扩张卷积带来的“块效应”,对每个组生成的注意力图进行动态重校准。这样可以更好地保留局部纹理信息。通过对每个 LKAi 应用门控模块,生成门控注意力图 MLKAi。
  4. 聚合: 将所有 MLKAi 聚合,得到最终的注意力图。

MLKA 的优势:

  • 更全面的长距离相关性学习: 通过多尺度机制,MLKA 可以学习不同尺度上的长距离相关性,从而更好地恢复图像细节。
  • 避免“块效应”: 通过门控机制,MLKA 可以有效地避免扩张卷积带来的“块效应”,从而更好地保留图像的平滑性。
  • 计算效率高: MLKA 通过大卷积核分解和门控机制,实现了计算效率的提升。

Multi-scale Large Kernel Attention 结构图:
在这里插入图片描述


2、代码实现

import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LayerNorm(nn.Module):def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):super().__init__()self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.eps = epsself.data_format = data_formatif self.data_format not in ["channels_last", "channels_first"]:raise NotImplementedErrorself.normalized_shape = (normalized_shape,)def forward(self, x):if self.data_format == "channels_last":return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)elif self.data_format == "channels_first":u = x.mean(1, keepdim=True)s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return xclass MLKA(nn.Module):def __init__(self, n_feats, k=2, squeeze_factor=15):super().__init__()i_feats = 2 * n_featsself.norm = LayerNorm(n_feats, data_format='channels_first')self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)# Multiscale Large Kernel Attentionself.LKA7 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA5 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA3 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)self.proj_first = nn.Sequential(nn.Conv2d(n_feats, i_feats, 1, 1, 0))self.proj_last = nn.Sequential(nn.Conv2d(n_feats, n_feats, 1, 1, 0))def forward(self, x, pre_attn=None, RAA=None):shortcut = x.clone()x = self.norm(x)x = self.proj_first(x)a, x = torch.chunk(x, 2, dim=1)a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],dim=1)x = self.proj_last(x * a) * self.scale + shortcutreturn xif __name__ == '__main__':x = torch.randn(4, 360, 64, 64).cuda()model = MLKA(360).cuda()out = model(x)print(out.shape)

相关文章:

  • 我写了一个分析 Linux 平台打开文件描述符跨进程传递的工具
  • 学习黑客网络安全法
  • Docker与WSL2如何清理
  • WebRTC 服务器之Janus架构分析
  • 【JAVA】数组与内存模型:二维数组底层实现(9)
  • 2.2 矩阵
  • NV203NV207SSD固态闪存NV208NV213
  • Maven 实现多模块项目依赖管理
  • neo4j初尝试
  • YOLOv11改进:利用RT-DETR主干网络PPHGNetV2助力轻量化目标检测
  • Excel-CLI:终端中的轻量级Excel查看器
  • 普通IT的股票交易成长史--20250502 突破(2)
  • 硬件工程师面试常见问题(12)
  • ES6/ES11知识点 续一
  • JavaScript性能优化实战之调试与性能检测工具
  • 【Hive入门】Hive与Spark SQL深度集成:通过Spark ThriftServer高效查询Hive表
  • 【Hive入门】Hive与Spark SQL深度集成:执行引擎性能全面对比与调优分析
  • 【Linux】Petalinux驱动开发基础
  • 学习黑客安全基础理论入门
  • Vue3源码学习6-增强鲁棒性?
  • 央行宣布优化两项支持资本市场的货币政策工具
  • 魔都眼|上海环球马术冠军赛收官,英国骑手夺冠
  • 美国得克萨斯州发生5.4级地震,震源深度10千米
  • 老人误操作免密支付买几百只鸡崽,经济日报:支付要便捷也要安全
  • “五一”假期第三天,预计全社会跨区域人员流动量超2.8亿人次
  • 五一假期首日,上海外滩客流超55万人次