(即插即用模块-Attention部分) 六十五、(2024 WACV) DLKA 可变形大核注意力
文章目录
- 1、Deformable Large Kernel Attention
- 2、代码实现
paper:Beyond Self-Attention: Deformable Large Kernel Attention for Medical Image Segmentation
Code:https://github.com/mindflow-institue/deformableLKA
1、Deformable Large Kernel Attention
Transformer 的局限性: 尽管 Transformer 在捕捉全局信息方面表现出色,但其计算量随 token 数量的平方增长,限制了其深度和分辨率能力。CNN 的局限性: CNN 在提取局部细节方面表现出色,但缺乏捕捉全局信息的机制。而现有的分割方法要么依赖于 CNN 的局部信息提取能力,要么使用 Transformer 的全局信息捕捉能力,缺乏两者之间的平衡。这篇论文在LKA的基础上提出一种 可变形大核注意力(Deformable Large Kernel Attention), D-LKA 模块结合了 LKA 和可变形卷积的优势,能够在保证计算效率的同时,更好地捕捉局部和全局信息。
实现过程:
- 深度可分离卷积:使用深度可分离卷积将特征图分解为通道维度,减少参数量和计算量。
- 深度可分离膨胀卷积:在深度可分离卷积的基础上,进一步增加感受野的大小。
- 1x1 卷积:对得到的特征图进行 1x1 卷积,以调整通道数量。
- 注意力机制:将上述步骤得到的特征图与原始特征图进行点积运算,得到注意力图,表示不同特征之间的相对重要性。
- 输出特征图:将注意力图与原始特征图进行逐元素相乘,并添加残差连接,得到最终的输出特征图。
优势:
- 平衡局部和全局信息: D-LKA 模块能够在保证计算效率的同时,更好地捕捉局部和全局信息,从而实现更准确的分割结果。
- 适应不规则形状和大小: 可变形卷积能够灵活地调整采样网格,从而更好地适应不规则形状和大小的目标。
- 高效计算: D-LKA 模块的计算量远低于 Transformer,能够更有效地处理 3D 数据。
Deformable Large Kernel Attention 结构图:
2、代码实现
import torch
import torch.nn as nn
import torchvisionclass DeformConv(nn.Module):def __init__(self, in_channels, groups, kernel_size=(3, 3), padding=1, stride=1, dilation=1, bias=True):super(DeformConv, self).__init__()self.offset_net = nn.Conv2d(in_channels=in_channels,out_channels=2 * kernel_size[0] * kernel_size[1],kernel_size=kernel_size,padding=padding,stride=stride,dilation=dilation,bias=True)self.deform_conv = torchvision.ops.DeformConv2d(in_channels=in_channels,out_channels=in_channels,kernel_size=kernel_size,padding=padding,groups=groups,stride=stride,dilation=dilation,bias=False)def forward(self, x):offsets = self.offset_net(x)out = self.deform_conv(x, offsets)return outclass deformable_LKA(nn.Module):def __init__(self, dim):super().__init__()self.conv0 = DeformConv(dim, kernel_size=(5, 5), padding=2, groups=dim)self.conv_spatial = DeformConv(dim, kernel_size=(7, 7), stride=1, padding=9, groups=dim, dilation=3)self.conv1 = nn.Conv2d(dim, dim, 1)def forward(self, x):u = x.clone()attn = self.conv0(x)attn = self.conv_spatial(attn)attn = self.conv1(attn)return u * attnclass deformable_LKA_Attention(nn.Module):def __init__(self, d_model):super().__init__()self.proj_1 = nn.Conv2d(d_model, d_model, 1)self.activation = nn.GELU()self.spatial_gating_unit = deformable_LKA(d_model)self.proj_2 = nn.Conv2d(d_model, d_model, 1)def forward(self, x):shorcut = x.clone()x = self.proj_1(x)x = self.activation(x)x = self.spatial_gating_unit(x)x = self.proj_2(x)x = x + shorcutreturn xif __name__ == '__main__':x = torch.randn(4, 64, 128, 128).cuda()model = deformable_LKA_Attention(64).cuda()out = model(x)print(out.shape)