05 - SimAM模块
论文《SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks》
1、作用
SimAM(Simple Attention Module)提出了一个概念简单但非常有效的注意力模块,用于卷积神经网络。与现有的通道维度和空间维度注意力模块不同,SimAM能够为特征图中的每个神经元推断出3D注意力权重,而无需在原始网络中添加参数。
2、机制
1、能量函数优化:
SimAM基于著名的神经科学理论,通过优化一个能量函数来找出每个神经元的重要性。这个过程不添加任何新参数到原始网络中。
2、快速闭合形式解决方案:
对于能量函数,SimAM推导出了一个快速的闭合形式解决方案,并展示了这个解决方案可以在不到十行代码中实现。这种方法避免了结构调整的繁琐工作,使模块的设计更为简洁高效。
3、独特优势
1、无参数设计:
SimAM的一个显著优势是它不增加任何额外的参数。这使得SimAM可以轻松地集成到任何现有的CNN架构中,几乎不增加计算成本。
2、直接生成3D权重:
与大多数现有的注意力模块不同,SimAM能够直接为每个神经元生成真正的3D权重,而不是仅仅在通道或空间维度上。这种全面的注意力机制能够更精确地捕捉到重要的特征信息。
3、基于神经科学的设计:
SimAM的设计灵感来自于人类大脑中的注意力机制,尤其是空间抑制现象,使其在捕获视觉任务中的关键信息方面更为高效和自然。
4、代码
import torch
import torch.nn as nn
from thop import profile # 引入thop库来计算模型的FLOPs和参数数量# 定义SimAM模块
class Simam_module(torch.nn.Module):def __init__(self, e_lambda=1e-4):super(Simam_module, self).__init__()self.act = nn.Sigmoid() # 使用Sigmoid激活函数self.e_lambda = e_lambda # 定义平滑项e_lambda,防止分母为0def forward(self, x):b, c, h, w = x.size() # 获取输入x的尺寸n = w * h - 1 # 计算特征图的元素数量减一,用于下面的归一化# 计算输入特征x与其均值之差的平方x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)# 计算注意力权重y,这里实现了SimAM的核心计算公式y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5# 返回经过注意力加权的输入特征return x * self.act(y)# 示例使用
if __name__ == '__main__':model = Simam_module().cuda() # 实例化SimAM模块并移到GPU上x = torch.randn(1, 3, 64, 64).cuda() # 创建一个随机输入并移到GPU上y = model(x) # 将输入传递给模型print(y.size()) # 打印输出尺寸# 使用thop库计算模型的FLOPs和参数数量flops, params = profile(model, (x,))print(flops / 1e9) # 打印以Giga FLOPs为单位的浮点操作数print(params) # 打印模型参数数量