当前位置: 首页 > news >正文

每日Attention学习26——Dynamic Weighted Feature Fusion

模块出处

[ACM MM 23] [link] [code] Efficient Parallel Multi-Scale Detail and Semantic Encoding
Network for Lightweight Semantic Segmentation


模块名称

Dynamic Weighted Feature Fusion (DWFF)


模块作用

双级特征融合


模块结构

在这里插入图片描述


模块思想

我们提出了 DWFF 策略,选择性地关注特征图中信息量最大的部分,以有效地结合浅层和深层特征,提高分割精度。DWFF 可用于在具有细粒度细节的区域中更重地加权浅层特征,在具有较高语义信息的区域中更重地加权深层特征,从而实现更好的特征组合和准确的分割。


模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWFF(nn.Module):
    def __init__(self,
                 in_channels: int,
                 height: int = 2,
                 reduction: int = 8,
                 bias: bool = False) -> None:
        super(DWFF, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(0.2)
        )
        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
        return feats_V
    

if __name__ == '__main__':
    dwff = DWFF(in_channels=64)
    x1 = torch.randn([2, 64, 16, 16])
    x2 = torch.randn([2, 64, 16, 16])
    out = dwff([x1, x2])
    print(out.shape)  # 2, 64, 16, 16

相关文章:

  • 双指针算法专题之——有效三角形的个数
  • 《Python深度学习》第二讲:深度学习的数学基础
  • 老牌软件,方便处理图片,量大管饱。
  • 4大观点直面呈现|直播回顾-DeepSeek时代的AI算力管理
  • 《灵珠觉醒:从零到算法金仙的C++修炼》卷三·天劫试炼(35)山河社稷图展开 - 编辑距离(字符串DP)
  • 向量数据库技术系列二-Milvus介绍
  • 【linux篇】--linux常见指令
  • 简单爬虫--框架
  • [蓝桥杯 2023 省 A] 买瓜 --暴力DFS+剪枝优化
  • L1-078 吉老师的回归(C++)
  • 202503执行jmeter压测数据库(ScyllaDB,redis,lindorm,Mysql)
  • 前缀和的例题
  • 麒麟系统使用-安装 SQL Developer
  • 【MIMIC数据库教程】十二、使用Python提取所有患者的高密度脂蛋白(HDL)指标
  • 【C++】 —— 笔试刷题day_6
  • [网络] socket编程--udp_echo_server
  • 深度解析前端面试八股文:核心知识点与高效应对策略
  • BigEvent项目后端学习笔记(一)用户管理模块 | 注册登录与用户信息全流程解析(含优化)
  • docker入门篇
  • 【极光 Orbit·STC8x】05. GPIO库函数驱动LED流动
  • 论文答辩ppt模板免费下载 素材/安卓优化
  • 海盐建设局网站/跨境电商网站
  • 网站源码程序/关键词检索
  • 做爰网站美女图片/业务员用什么软件找客户
  • 一个主机可以建设多少个网站/网站流量统计查询
  • 池州网站建设公司/服务网站排名咨询