当前位置: 首页 > news >正文

【Attention】SKAttention

SKAttention选择核注意力

标题:SKAttention

期刊:IEEE2019

代码: https://github.com/implus/SKNet

简介:

  • 动机:增大感受野来提升性能、多尺度信息聚合方式
  • 解决的问题:自适应调整感受野大小
  • 创新性:提出选择性内核(SK)卷积softmax来进行自适应选择

模型结构

在这里插入图片描述

模型代码

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict

# Selective Kernel Attention
class SKAttention(nn.Module):

    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        # 中间维度d的计算
        self.d = max(L, channel // reduction)
        # 多分支卷积层(使用不同尺寸的卷积核)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    # 分组卷积(输入输出通道数相同,保持维度)
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    # 批归一化(保持维度)  
                    ('bn', nn.BatchNorm2d(channel)),
                    # ReLU激活函数
                    ('relu', nn.ReLU())
                ]))
            )
        # # 通道压缩层(全连接层)
        self.fc = nn.Linear(channel, self.d)
        # 多分支注意力权重生成层
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        # 注意力权重归一化(沿分支维度softmax)
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):# 输入x形状: [B, C, H, W]
        bs, c, _, _ = x.size() # 获取输入的batch_size, 通道数, 高度, 宽度
        conv_outs = []
        ### Split阶段:多分支特征提取
        for conv in self.convs:
            conv_outs.append(conv(x)) # 每个分支输出: [B, C, H, W]
        feats = torch.stack(conv_outs, 0)  # 堆叠后形状: [K, B, C, H, W](K是kernel数量)

        ### Fuse阶段:特征融合
        U = sum(conv_outs) # 逐元素相加 → [B, C, H, W]

        ### Channel Reduction:通道压缩
        S = U.mean(-1).mean(-1)  # 空间全局平均池化 → [B, C,1,1]
        Z = self.fc(S)   # 全连接层降维 → [B, d](d=self.d)

        ### 计算注意力权重
        weights = []
        for fc in self.fcs: #  每个kernel对应一个全连接层
            weight = fc(Z) # 全连接层输出 → [B, C]
            weights.append(weight.view(bs, c, 1, 1))  # 调整形状 → [B, C, 1, 1]
        attention_weughts = torch.stack(weights, 0)   # 堆叠 → [K, B, C, 1, 1]
        attention_weughts = self.softmax(attention_weughts)  # 沿K维度softmax归一化

        ### fuse
        V = (attention_weughts * feats).sum(0) # 加权求和 → [B, C, H, W]
        return V


if __name__ == '__main__':
    input = torch.rand(1,64,256,256).cuda()
    model = SKAttention(channel=64, reduction=8).cuda()
    output = model (input)
    print('input_size:', input.size())
    print('output_size:', output.size())
    print("最大内存占用:", torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")
    

相关文章:

  • 优先队列-小根堆留坑
  • 使用 Node.js 读取 Excel 文件并处理合并单元格
  • Spring:AOP
  • 网络HTTPS协议
  • SOFABoot-08-启动加速
  • 修改服务器windows远程桌面默认端口号
  • 苹果iPhone屏幕防护专利获批,未来iPhone或更耐用
  • Linux 通过压缩包安装 MySQL 并设置远程连接教程
  • Nginx及前端部署全流程:初始化配置到生产环境部署(附Nginx常用命令)
  • I/O 多路复用(I/O Multiplexing)
  • Java面试黄金宝典9
  • Linux | ubuntu安装 SSH 软件及测试工具
  • 挂谷猜想的证明错误百出
  • 嵌入式基础知识学习:SPI通信协议是什么?
  • 【趣谈】了解语音拼写检查算法的内部机制
  • PTA团体程序设计天梯赛-练习集71-75题
  • 2025年渗透测试面试题总结-某深信服 -安全工程师(题目+回答)
  • 关于转嵌入式的一点想法
  • 不做颠覆者,甘为连接器,在技术叠层中培育智能新物种
  • 蓝桥杯(N皇后问题)------回溯法
  • 外卖员投资失败负疚离家流浪,经民警劝回后泣不成声给父母下跪
  • 越秀地产前4个月销售额约411.2亿元,达年度销售目标的34.1%
  • 罗氏制药全新生物制药生产基地投资项目在沪启动:预计投资20.4亿元,2031年投产
  • 复旦设立新文科发展基金,校友曹国伟、王长田联合捐赠1亿助力人文学科与社会科学创新
  • 一季度全国消协组织为消费者挽回经济损失23723万元
  • 对话|蓬皮杜策展人布莱昂:抽象风景中的中国审美