当前位置: 首页 > news >正文

ECA注意力机制改进思路

摘要回顾

近年来,通道注意机制在提高深度卷积神经网络(cnn)性能方面显示出巨大的潜力。 然而,大多数现有方法致力于开发更复杂的注意力模块以获得更好的性能,这不可避免地增加了模型的复杂性。 为了克服性能和复杂性权衡的矛盾,提出了一种高效通道注意(ECA)模块,该模块只涉及少量参数,但能带来明显的性能增益。 通过对SENet中信道注意模块的分析,我们通过经验证明避免降维对于学习信道注意非常重要,适当的跨信道交互可以在显著降低模型复杂性的同时保持性能。 因此,我们提出了一种不降维的局部跨通道交互策略,该策略可以通过一维卷积有效地实现。 此外,我们还开发了一种自适应选择一维卷积核大小的方法,以确定局部跨通道相互作用的覆盖范围。 所提出的ECA模块是高效而有效的,例如,我们的模块对ResNet50主干的参数和计算分别为80 vs. 24.37M和4.7e-4 GFLOPs vs. 3.86 GFLOPs,就top -1精度而言,性能提升超过2%。 我们广泛评估了我们的ECA模块在图像分类,目标检测和实例分割与resnets和MobileNetV2的主干。 实验结果表明,该模块效率更高,性能优于同类模块。

性能对比图

方法思路

通过分析降维效应和跨通道交互作用对SE阻塞进行实证诊断。 这促使我们提出我们的ECA模块。 此外,我们还开发了一种自适应确定ECA参数的方法,最后展示了如何将其应用于深度cnn。

原始代码 

import torch, math
from torch import nn

class EfficientChannelAttention(nn.Module):   
    def __init__(self, c, b=1, gamma=2):
        super(EfficientChannelAttention, self).__init__()
        t = int(abs((math.log(c, 2) + b) / gamma))
        k = t if t % 2 else t + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k/2), bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        # batchsize = 1,channel = c,h,w
        out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        out = self.sigmoid(out)
        return out * x

改进思路

1.动态核尺寸 + 多尺度卷积融合 +动态权重融合


import torch
import math
from torch import nn

class DynamicMultiKernelECA(nn.Module):
    def __init__(self, c, gamma_range=(2, 8), base_b=1):
        super().__init__()
        self.gamma_min, self.gamma_max = gamma_range
        self.base_b = base_b

        # 多尺度卷积组
        self.conv3 = nn.Conv1d(c, c, kernel_size=3, padding=1, groups=c, bias=False)
        self.conv5 = nn.Conv1d(c, c, kernel_size=5, padding=2, groups=c, bias=False)
        self.conv7 = nn.Conv1d(c, c, kernel_size=7, padding=3, groups=c, bias=False)

        # 动态权重生成
        self.weight_net = nn.Sequential(
            nn.Linear(c, 3),
            nn.Softmax(dim=1)
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()

        # 动态计算gamma值
        gamma = self.gamma_min + (self.gamma_max - self.gamma_min) * torch.sigmoid(
            torch.mean(x, dim=[2,3], keepdim=True))

        # 修正维度变换
        y = self.avg_pool(x).squeeze(-1)  # Shape: (b, c, 1)

        # 多分支卷积
        conv3_out = self.conv3(y)  # Shape: (b, c, 1)
        conv5_out = self.conv5(y)
        conv7_out = self.conv7(y)

        # 动态权重融合
        weights = self.weight_net(y.squeeze(-1))  # 输入形状 (b, c)
        combined = (
            weights[:, 0].unsqueeze(-1).unsqueeze(-1) * conv3_out +
            weights[:, 1].unsqueeze(-1).unsqueeze(-1) * conv5_out +
            weights[:, 2].unsqueeze(-1).unsqueeze(-1) * conv7_out
        )

        # 空间重构
        out = self.sigmoid(combined).view(b, c, 1, 1)
        return x * out.expand_as(x)

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    eca = DynamicMultiKernelECA(c=512)
    output = eca(input)
    print(f"输入形状: {input.shape} → 输出形状: {output.shape}")
    # 输出应为 torch.Size([50, 512, 7, 7])

2.双向时序注意增强

class BidirectionalECA(nn.Module):
    def __init__(self, c, groups=8):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # 双向卷积
        self.conv_forward = nn.Conv1d(1, 1, kernel_size=3, padding=1, groups=1, bias=False)
        self.conv_backward = nn.Conv1d(1, 1, kernel_size=3, padding=1, groups=1, bias=False)

        # 门控机制
        self.gate = nn.Sequential(
            nn.Linear(c, c // 16),
            nn.ReLU(),
            nn.Linear(c // 16, 2),  # 输出2维权重
            nn.Softmax(dim=1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()

        # 特征压缩
        y = self.avg_pool(x).squeeze(-1).permute(0, 2, 1)  # (b, 1, c)

        # 前向传播
        forward = self.conv_forward(y).sigmoid()  # (b, 1, c)

        # 逆向传播
        reversed_y = torch.flip(y, dims=[-1])
        backward = torch.flip(self.conv_backward(reversed_y), dims=[-1]).sigmoid()

        # 门控融合(维度对齐)
        gate_weights = self.gate(y.mean(dim=1))  # (b, 2)
        combined = (gate_weights[:, 0].view(b, 1, 1) * forward +
                    gate_weights[:, 1].view(b, 1, 1) * backward)

        # 空间重构
        out = self.sigmoid(combined).permute(0, 2, 1).unsqueeze(-1)  # (b, c, 1, 1)
        return x * out.expand_as(x)  

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    eca = BidirectionalECA(c=512)
    output = eca(input)
    print(f"输入形状: {input.shape} → 输出形状: {output.shape}")
    # 输出应为 torch.Size([50, 512, 7, 7])

相关文章:

  • 第三章-PHP流程控制语句
  • Linux 运行级别
  • 带宽管理配置实验
  • 【Azure 架构师学习笔记】- Azure Databricks (21) --费用相关
  • 进程管理:前后台切换
  • 3U VPX 国产化板卡FT6678+V7 690T
  • 格式化输出备忘
  • css的显示模式
  • fs的proxy_media模式失效
  • 网络安全 与 加密算法
  • ngx_command_t
  • Spring Cloud LoadBalancer 原理与实践
  • 网络安全——SpringBoot配置文件明文加密
  • 三相逆变器不控整流场景简要分析
  • 【6】拓扑排序学习笔记
  • 什么是 Redis
  • 【QT】】qcustomplot的使用
  • leecode797.所有可能的路径
  • WPF窗口读取、显示、修改、另存excel文件——CAD c#二次开发
  • TEXT()的作用
  • 专利申请全球领先!去年我国卫星导航与位置服务产值超5700亿
  • 竞彩湃|足总杯决赛或有冷门,德甲欧冠资格之争谁笑到最后
  • 孟夏韵评《无序的学科》丨误读与重构的文化漂流
  • 中国首艘海洋级智能科考船“同济”号试航成功,可搭载水下遥控机器人
  • 四川内江警方通报一起持刀伤人致死案:因车辆停放引起,嫌犯被抓获
  • 丹麦外交大臣拉斯穆森将访华