当前位置: 首页 > news >正文

每日Attention学习25——Multi-Scale Attention Fusion

模块出处

[TCSVT 24] [link] [code] DSNet: A Novel Way to Use Atrous Convolutions in Semantic Segmentation


模块名称

Multi-Scale Attention Fusion (MSAF)


模块作用

双级特征融合


模块结构

在这里插入图片描述


模块思想

MSAF的主要思想是让网络根据损失学习特征权重,允许模型选择性地融合来自不同尺度的信息。


模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class MSAF(nn.Module):
    def __init__(self, channels=64, r=4):
        super(MSAF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.context1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels)
        )

        self.context2 = nn.Sequential(
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels)
        )

        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        h, w = x.shape[2], x.shape[3]
        xa = x + residual
        xl = self.local_att(xa)
        c1 = self.context1(xa)
        c2 = self.context2(xa)
        xg = self.global_att(xa)
        c1 = F.interpolate(c1, size=[h, w], mode='nearest')
        c2 = F.interpolate(c2, size=[h, w], mode='nearest')
        xlg = xl + xg + c1 + c2 
        wei = self.sigmoid(xlg)
        xo = 2 * x * wei + 2 * residual * (1 - wei)
        return xo
    

if __name__ == '__main__':
    msaf = MSAF()
    x1 = torch.randn([2, 64, 16, 16])
    x2 = torch.randn([2, 64, 16, 16])
    out = msaf(x1, x2)  
    print(out.shape)  # 2, 64, 16, 16

http://www.dtcms.com/a/73119.html

相关文章:

  • 结构体1~5(1414. 期末考试成绩排名、1490. 坐标排序、1315. 遥控飞机争夺赛、1730. 购买贺年卡、1499. 宇宙总统2)
  • Windows Qt动态监测系统分辨率及缩放比变化
  • LGA封装 Z3588开发板,8K视频编解码
  • 设计模式使用Java案例
  • 《AI大模型趣味实战》No2 : 快速搭建一个漂亮的AI家庭网站-相册/时间线/日历/多用户/个性化配色(中)
  • Leetcode-131.Palindrome Partitioning [C++][Java]
  • RUOYI框架在实际项目中的应用三:Ruoyi微服务版本-RuoYi-Cloud
  • JAVA数据库技术(一)
  • Deepseek学习--工具篇之Ollama
  • 基于C#的以太网通讯实现:TcpClient异步通讯详解
  • 设置echarts legend 图例与文字对齐
  • 股指期货有卖不出去的时候吗?
  • 在线 SQL 转 flask SQLAlchemy 模型
  • ctf web入门知识合集
  • 阿里wan2.1本地部署
  • Webpack总结
  • MySQL配置文件my.cnf详解
  • 抽象工厂模式 (Abstract Factory Pattern)
  • 蓝桥杯专项复习——结构体、输入输出
  • 花生好车:重构汽车新零售生态的破局者
  • HTML5前端第三章节
  • Centos离线安装openssl-devel
  • 【深度学习与大模型基础】第5章-线性相关与生成子空间
  • 音视频缓存数学模型
  • AI-医学影像分割方法与流程
  • Spring Validation参数校验
  • P1118 [USACO06FEB] Backward Digit Sums G/S
  • 前端项目的构建流程无缝集成到 Maven 生态系统(一)
  • C Sharp 集合
  • 包装类简单认识泛型