(即插即用模块-特征处理部分) 四十一、(2024) MSAA 多尺度注意力聚合模块
文章目录
- 1、Multi-Scale Attention Aggregation Module
- 2、代码实现
paper:CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation
Code:https://github.com/XiaoBuL/CM-UNet
1、Multi-Scale Attention Aggregation Module
传统跳连接的局限性: UNet 等模型中常用的跳连接方式,将低层特征图直接加到高层特征图上,无法有效融合不同尺度的信息,限制了分割效果的提升。提取多尺度信息的重要性: 遥感图像分割需要考虑不同尺度的信息,例如大尺度信息可以提供全局场景理解,小尺度信息可以捕捉细节特征。所以,这篇论文提出一种 多尺度注意力聚合模块(Multi-Scale Attention Aggregation Module),MSAA 模块旨在通过聚合不同尺度的特征,从而更好地捕捉多尺度上下文信息。
实现过程:
- 特征融合: 将编码器不同阶段的特征图进行融合,形成新的特征图 ˆF。
- 空间路径:使用 1x1 卷积将特征图的通道数减少,降低维度。然后对不同核大小的卷积结果进行求和,例如 3x3, 5x5, 7x7 卷积,以融合不同尺度的特征。最后使用平均池化和最大池化对空间特征进行聚合,并通过 7x7 卷积和 sigmoid 激活函数进行非线性变换。
- 通道路径: 将特征图的空间维度压缩为 1x1,并进行全局平均池化,提取全局信息。 使用 1x1 卷积和 ReLU 激活函数生成通道注意力图,并扩大其尺寸以匹配输入特征图的维度。
- 特征整合: 将空间路径和通道路径的结果进行元素级相加,得到最终的输出特征图。
优势:
- 有效地融合多尺度信息: MSAA 模块可以有效地融合不同尺度的特征,从而更好地捕捉多尺度上下文信息,提升分割精度。
- 增强分割细节: MSAA 模块可以增强分割细节,例如建筑物边缘、道路等,从而提高分割结果的清晰度。
Multi-Scale Attention Aggregation 结构图:
2、代码实现
import torch
import torch.nn as nnclass ChannelAttentionModule(nn.Module):def __init__(self, in_channels, reduction=4):super(ChannelAttentionModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttentionModule(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttentionModule, self).__init__()self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x)class MSAA(nn.Module):def __init__(self, in_channels, out_channels, factor=4.0):super(MSAA, self).__init__()dim = int(out_channels // factor)self.down = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)self.conv_3x3 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)self.conv_5x5 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2)self.conv_7x7 = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3)self.spatial_attention = SpatialAttentionModule()self.channel_attention = ChannelAttentionModule(dim)self.up = nn.Conv2d(dim, out_channels, kernel_size=1, stride=1)def forward(self, x1, x2, x4):# # x2 是从低到高,x4是从高到低的设计,x2传递语义信息,x4传递边缘问题特征补充# x_1_2_fusion = self.fusion_1x2(x1, x2)# x_1_4_fusion = self.fusion_1x4(x1, x4)# x_fused = x_1_2_fusion + x_1_4_fusionx_fused = torch.cat([x1, x2, x4], dim=1)x_fused = self.down(x_fused)x_fused_c = x_fused * self.channel_attention(x_fused)x_3x3 = self.conv_3x3(x_fused)x_5x5 = self.conv_5x5(x_fused)x_7x7 = self.conv_7x7(x_fused)x_fused_s = x_3x3 + x_5x5 + x_7x7x_fused_s = x_fused_s * self.spatial_attention(x_fused_s)x_out = self.up(x_fused_s + x_fused_c)return x_outif __name__ == '__main__':x = torch.randn(4, 64, 128, 128).cuda()y = torch.randn(4, 64, 128, 128).cuda()z = torch.randn(4, 64, 128, 128).cuda()model = MSAA(192, 64).cuda()out = model(x, y, z)print(out.shape)