当前位置: 首页 > news >正文

(即插即用模块-Attention部分) 六十五、(2024 WACV) DLKA 可变形大核注意力

在这里插入图片描述

文章目录

  • 1、Deformable Large Kernel Attention
  • 2、代码实现

paper:Beyond Self-Attention: Deformable Large Kernel Attention for Medical Image Segmentation

Code:https://github.com/mindflow-institue/deformableLKA


1、Deformable Large Kernel Attention

Transformer 的局限性: 尽管 Transformer 在捕捉全局信息方面表现出色,但其计算量随 token 数量的平方增长,限制了其深度和分辨率能力。CNN 的局限性: CNN 在提取局部细节方面表现出色,但缺乏捕捉全局信息的机制。而现有的分割方法要么依赖于 CNN 的局部信息提取能力,要么使用 Transformer 的全局信息捕捉能力,缺乏两者之间的平衡。这篇论文在LKA的基础上提出一种 可变形大核注意力(Deformable Large Kernel Attention), D-LKA 模块结合了 LKA 和可变形卷积的优势,能够在保证计算效率的同时,更好地捕捉局部和全局信息。

实现过程:

  1. 深度可分离卷积:使用深度可分离卷积将特征图分解为通道维度,减少参数量和计算量。
  2. 深度可分离膨胀卷积:在深度可分离卷积的基础上,进一步增加感受野的大小。
  3. 1x1 卷积:对得到的特征图进行 1x1 卷积,以调整通道数量。
  4. 注意力机制:将上述步骤得到的特征图与原始特征图进行点积运算,得到注意力图,表示不同特征之间的相对重要性。
  5. 输出特征图:将注意力图与原始特征图进行逐元素相乘,并添加残差连接,得到最终的输出特征图。

优势:

  • 平衡局部和全局信息: D-LKA 模块能够在保证计算效率的同时,更好地捕捉局部和全局信息,从而实现更准确的分割结果。
  • 适应不规则形状和大小: 可变形卷积能够灵活地调整采样网格,从而更好地适应不规则形状和大小的目标。
  • 高效计算: D-LKA 模块的计算量远低于 Transformer,能够更有效地处理 3D 数据。

Deformable Large Kernel Attention 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torchvisionclass DeformConv(nn.Module):def __init__(self, in_channels, groups, kernel_size=(3, 3), padding=1, stride=1, dilation=1, bias=True):super(DeformConv, self).__init__()self.offset_net = nn.Conv2d(in_channels=in_channels,out_channels=2 * kernel_size[0] * kernel_size[1],kernel_size=kernel_size,padding=padding,stride=stride,dilation=dilation,bias=True)self.deform_conv = torchvision.ops.DeformConv2d(in_channels=in_channels,out_channels=in_channels,kernel_size=kernel_size,padding=padding,groups=groups,stride=stride,dilation=dilation,bias=False)def forward(self, x):offsets = self.offset_net(x)out = self.deform_conv(x, offsets)return outclass deformable_LKA(nn.Module):def __init__(self, dim):super().__init__()self.conv0 = DeformConv(dim, kernel_size=(5, 5), padding=2, groups=dim)self.conv_spatial = DeformConv(dim, kernel_size=(7, 7), stride=1, padding=9, groups=dim, dilation=3)self.conv1 = nn.Conv2d(dim, dim, 1)def forward(self, x):u = x.clone()attn = self.conv0(x)attn = self.conv_spatial(attn)attn = self.conv1(attn)return u * attnclass deformable_LKA_Attention(nn.Module):def __init__(self, d_model):super().__init__()self.proj_1 = nn.Conv2d(d_model, d_model, 1)self.activation = nn.GELU()self.spatial_gating_unit = deformable_LKA(d_model)self.proj_2 = nn.Conv2d(d_model, d_model, 1)def forward(self, x):shorcut = x.clone()x = self.proj_1(x)x = self.activation(x)x = self.spatial_gating_unit(x)x = self.proj_2(x)x = x + shorcutreturn xif __name__ == '__main__':x = torch.randn(4, 64, 128, 128).cuda()model = deformable_LKA_Attention(64).cuda()out = model(x)print(out.shape)
http://www.dtcms.com/a/172297.html

相关文章:

  • 方法:批量识别图片区域文字并重命名,批量识别指定区域内容改名,基于QT和阿里云的实现方案,详细方法
  • GGD独立站的优势
  • 如何判断cgroup的版本?
  • 【PostgreSQL数据分析实战:从数据清洗到可视化全流程】4.3 数据脱敏与安全(模糊处理/掩码技术)
  • SpringBoot实战:整合Knife4j
  • 前端懒加载(Lazy Loading)实战指南
  • 开元类双端互动组件部署实战全流程教程(第3部分:UI资源加载机制与界面逻辑全面解析
  • 金仓数据库 KingbaseES 在电商平台数据库迁移与运维中深入复现剖析
  • C++和Lua混和调用
  • 编译原理期末重点-个人总结——2 文法与语言
  • 相同IP和端口的服务器ssh连接时出现异常
  • 36-校园反诈系统(小程序)
  • JS DAY4 日期对象与节点
  • JAVA简单走进AI世界~Spring AI
  • Ubuntu K8S(1.28.2) 节点/etc/kubernetes/manifests 不存在
  • 二、【LLaMA-Factory实战】数据工程全流程:从格式规范到高质量数据集构建
  • 虚幻引擎5-Unreal Engine笔记之显卡环境设置使开发流畅
  • springboot+mysql+element-plus+vue完整实现汽车租赁系统
  • Vue3携手Echarts,打造炫酷数据可视化大屏
  • Flutter——数据库Drift开发详细教程(四)
  • GZ人博会自然资源系统(测绘)备考笔记
  • 享元模式(Flyweight Pattern)详解
  • 小米刷新率 2.4 | 突破屏幕刷新率限制,享受更流畅视觉体验的应用程序
  • 内存碎片深度剖析
  • 十大排序算法全面解析(Java实现)及优化策略
  • Java SE(8)——继承
  • 残差网络实战:基于MNIST数据集的手写数字识别
  • 主机漏洞扫描:如何保障网络安全及扫描原理与类型介绍?
  • JVM 内存结构全解析
  • 【NLP】32. Transformers (HuggingFace Pipelines 实战)