当前位置: 首页 > news >正文

【Block总结】MSC,多尺度稀疏交叉注意力网络在遥感场景分类中的应用|即插即用

📄 论文基本信息

项目内容
标题Multiscale Sparse Cross-Attention Network for Remote Sensing Scene Classification (MSCN)
中文题目多尺度稀疏交叉注意力网络在遥感场景分类中的应用
发表期刊IEEE Transactions on Geoscience and Remote Sensing (TGRS) 2025
所属单位西安电子科技大学人工智能学院、南京理工大学计算机科学与工程学院
核心目标解决遥感图像中因目标尺度多变、小目标易丢失以及背景复杂带来的分类挑战。

💡 论文主要创新点

该论文的创新性主要体现在以下三个层面:

  1. 新颖的网络架构(MSCN):提出了一个多尺度稀疏交叉注意力网络,其核心思想是将高层语义特征与每一个浅层细节特征进行交叉融合,改变了传统方法中简单的特征拼接方式,从而更充分地挖掘不同层级特征中的互补信息。
  2. 核心的注意力机制(MSC):设计了多尺度稀疏交叉注意力(MSC)机制,通过多尺度池化捕捉复杂内容,并创新性地引入TopK稀疏操作过滤掉大量无关信息,从而在提升特征表达能力的同时,增强模型的抗干扰能力。
  3. 特征增强模块(GCE):开发了多组空间通道注意力机制(GCE),该模块通过分组处理并强化关键特征的注意力权重,使得最终用于分类的特征更具判别力。

在这里插入图片描述

🛠️ 方法详解:MSCN如何工作

MSCN的整体框架包含四个主要组件,其工作流程如下:

提取多层级特征
进一步增强特征
输出融合后的分类特征
MSC机制
特征映射为 Q K V
多尺度池化
双重TopK稀疏操作
自适应加权融合
遥感图像输入
多级特征提取模块 MFEM
多级特征交叉融合模块 MFCM
全局上下文增强器 GCE
分类模块 CM
最终分类结果

🔍 核心组件解析

  1. 多级特征提取模块(MFEM):该模块使用卷积神经网络(CNN)作为主干,从输入图像中提取出低、中、高三个层级的特征。其中,浅层特征包含丰富的细节信息(如纹理、边缘),而深层特征则包含更强的语义信息(如物体类别)。

  2. 多级特征交叉融合模块(MFCM):这是论文的核心。该模块通过多尺度稀疏交叉注意力(MSC)机制,将高层特征与每一个浅层特征进行交叉融合。这样做的好处是,能够用深层语义来引导和增强浅层特征,使细节信息也具有类别判别性,同时避免了小目标在深层网络中丢失的问题。

  3. 多尺度稀疏交叉注意力(MSC)机制:其内部操作可分为四步:

    • 多尺度信息提取:对输入特征进行不同窗口大小的平均池化,捕获从局部细节到全局结构的多种尺度信息。
    • 特征映射:将融合后的多尺度特征与另一输入特征分别映射为查询(Q)、键(K)、值(V)矩阵。
    • 稀疏操作:计算注意力矩阵后,使用两次TopK操作进行稀疏化。一次保留较多元素以把握整体关联,一次保留较少元素以聚焦最核心的部分,从而有效过滤掉无关噪声。
    • 自适应融合:通过可学习的参数自适应地融合两种稀疏化后的注意力图,最后与值矩阵加权得到输出。
  4. 全局上下文增强器(GCE)与分类模块(CM):GCE模块对融合后的特征进一步施加注意力,增强关键区域的权重。最后,CM模块执行最终的场景分类。
    在这里插入图片描述

代码

import torch
import torch.nn as nn
from einops import rearrange
from math import sqrtclass MSC(nn.Module):def __init__(self, dim, num_heads=8, topk=True, kernel=[3, 5, 7], s=[1, 1, 1], pad=[1, 2, 3],qkv_bias=False, qk_scale=None, attn_drop_ratio=0., proj_drop_ratio=0., k1=2, k2=3):super(MSC, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop_ratio)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop_ratio)self.k1 = k1self.k2 = k2self.attn1 = torch.nn.Parameter(torch.tensor([0.5]), requires_grad=True)self.attn2 = torch.nn.Parameter(torch.tensor([0.5]), requires_grad=True)# self.attn3 = torch.nn.Parameter(torch.tensor([0.3]), requires_grad=True)self.avgpool1 = nn.AvgPool2d(kernel_size=kernel[0], stride=s[0], padding=pad[0])self.avgpool2 = nn.AvgPool2d(kernel_size=kernel[1], stride=s[1], padding=pad[1])self.avgpool3 = nn.AvgPool2d(kernel_size=kernel[2], stride=s[2], padding=pad[2])self.layer_norm = nn.LayerNorm(dim)self.topk = topk  # False Truedef forward(self, x, y):# x0 = xy1 = self.avgpool1(y)y2 = self.avgpool2(y)y3 = self.avgpool3(y)# y = torch.cat([y1.flatten(-2,-1),y2.flatten(-2,-1),y3.flatten(-2,-1)],dim = -1)y = y1 + y2 + y3y = y.flatten(-2, -1)y = y.transpose(1, 2)y = self.layer_norm(y)x = rearrange(x, 'b c h w -> b (h w) c')# y = rearrange(y,'b c h w -> b (h w) c')B, N1, C = y.shape# print(y.shape)kv = self.kv(y).reshape(B, N1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)k, v = kv[0], kv[1]# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)B, N, C = x.shapeq = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)attn = (q @ k.transpose(-2, -1)) * self.scale# print(self.k1,self.k2)mask1 = torch.zeros(B, self.num_heads, N, N1, device=x.device, requires_grad=False)index = torch.topk(attn, k=int(N1 / self.k1), dim=-1, largest=True)[1]# print(index[0,:,48])mask1.scatter_(-1, index, 1.)attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))attn1 = attn1.softmax(dim=-1)attn1 = self.attn_drop(attn1)out1 = (attn1 @ v)mask2 = torch.zeros(B, self.num_heads, N, N1, device=x.device, requires_grad=False)index = torch.topk(attn, k=int(N1 / self.k2), dim=-1, largest=True)[1]# print(index[0,:,48])mask2.scatter_(-1, index, 1.)attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf')))attn2 = attn2.softmax(dim=-1)attn2 = self.attn_drop(attn2)out2 = (attn2 @ v)out = out1 * self.attn1 + out2 * self.attn2  # + out3 * self.attn3# out = out1 * self.attn1 + out2 * self.attn2x = out.transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)hw = int(sqrt(N))x = rearrange(x, 'b (h w) c -> b c h w', h=hw, w=hw)# x = x + x0return x# 测试代码
if __name__ == '__main__':input = torch.randn(1, 32, 64, 64)  # [B, C=32, H=64, W=64]input1 = torch.randn(1, 32, 64, 64)  # [B, C=32, H=64, W=64]model = MSC(dim=32)output = model(input,input1)print("输入张量形状:", input.shape)print("输出张量形状:", output.shape)

在这里插入图片描述

📊 效果与总结

性能效果

根据论文介绍,MSCN模型在多个公开遥感场景分类数据集上进行了实验,并取得了优异的效果。

  • 有效性:实验结果验证了MSCN模型及其核心组件(MSC机制、GCE模块)能够显著提升遥感场景分类的准确率。
  • 抗干扰性:得益于稀疏注意力机制,模型对复杂背景和无关信息具有更好的鲁棒性。
  • 小目标友好:多尺度交叉融合策略有效缓解了小目标特征在深层网络中被丢失的问题,提升了像“小型飞机”这类目标的分类精度。
总结

总而言之,这篇论文针对遥感场景分类中的核心难题,提出了一个结构新颖、机制有效的解决方案。MSCN网络通过多尺度稀疏交叉注意力成功实现了对多层次特征的高效融合与去噪,在保持计算效率的同时,显著提升了模型的分类性能。这项工作为后续的遥感图像分析研究提供了重要的技术参考和思路启发。

希望以上总结对你有所帮助。如果你对论文中提到的具体实验数据或与其它模型的对比细节感兴趣,我可以尝试帮你进一步寻找相关信息。

http://www.dtcms.com/a/564805.html

相关文章:

  • CANN卷积算子深度优化:以ResNet推理为例
  • PINN求解一维亥姆霍兹方程,以声学问题为例,使用L-BFGS优化器,将传统物理建模与现代深度学习相结合,为解决科学计算问题提供了新的范式,MATLAB代码
  • 文件存储服务有哪些?他们优缺点分别是什么?FastDFS、MinIO、Ceph、HDFS、MooseFS、TFS、七牛云、阿里云 OSS
  • 如何一键将 PDF 转为 Word?
  • 20-递归神经网络(Recursive NN):介绍递归神经网络的特点和用途
  • STM32H743-ARM例程36-DNS
  • 搜集素材的网站吉首建设局网站
  • 16、Docker swarm-3
  • A模块 系统与网络安全 第四门课 弹性交换网络-4
  • Jenkins 在构建 Java 项目并操作 Docker 时 CPU 会突然飙高
  • IP冲突排查方法
  • 爱漫画-只做精品的韩漫网站企业首次建设网站方案流程
  • Fiddler配置方法与使用教程:HTTP/HTTPS抓包分析、代理设置与调试技巧详解(开发者实战指南)
  • UML建模工具Enterprise Architect如何通过威胁建模保障系统安全
  • Android + Flutter打包出来的APK体积太大
  • risc-v vector.S解析
  • QT Quick QML项目音乐播放器15----设置全局背景与颜色、播放页面背景切换
  • 嵌入式系统与RISC-V:开源架构驱动的下一代创新
  • 【Android】Drawable 基础
  • Android和h5页面相互传参
  • 直接做的视频网站怎么用网站模板
  • 通信人C++自学
  • Azure OpenAI 资源监控与成本分析实战指南
  • Notepad++ 安装 json格式化 插件
  • CSS实现居中的8中方法
  • 在阿里云做的网站怎么进后台国际营销信息系统
  • (undone) CFD 学习 (5) 阅读简单求解器理解非结构网格 + 更复杂的 FVM
  • 福州网站建设机构东莞网站推广行者seo08
  • 唐山教育平台网站建设推广的软件有哪些
  • 算法24.0