
Attention Mechanism Module Code

  • Widely recommended: SE, ECA, Coordinate Attention (CA). Lightweight, easy to use, and consistently effective.

  • Still usable, but weigh the computational cost: BAM, GCNet, SKNet.

  • Generally not a first choice (somewhat dated or being phased out): Non-local, DANet; they are particularly hard to apply to large-scale or 3D medical images.

 

SE Module (Squeeze-and-Excitation, channel attention)

import torch
import torch.nn as nn


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # [B, C]
        y = self.fc(y).view(b, c, 1, 1)   # channel attention weights
        return x * y.expand_as(x)

Recommended insertion points:

  • As a channel-attention module after convolution layers, typically at the end of each convolutional or residual block;

  • After the convolutional output of each encoder stage, to recalibrate the channels.

What it does:

  • Generates channel weights via a "Squeeze" step (global average pooling) and an "Excitation" step (two fully connected layers);

  • Strengthens the model's response to key channels and suppresses irrelevant ones;

  • Simple structure, few parameters, easy to plug in (see the placement sketch below).
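A minimal placement sketch, assuming the SEBlock defined above; ResidualBlockWithSE and its channel counts are illustrative stand-ins, not taken from any specific network in this post:

import torch
import torch.nn as nn


class ResidualBlockWithSE(nn.Module):
    # Hypothetical residual block; only the placement of SEBlock at the end
    # of the block follows the recommendation above.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(channels, reduction)  # channel recalibration at the end of the block

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)                      # reweight channels before the residual add
        return self.relu(out + x)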

CBAM Module (Convolutional Block Attention Module, channel + spatial attention)

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_planes, in_planes // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(in_planes // reduction, in_planes, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        out = self.sigmoid(out).view(b, c, 1, 1)
        return x * out.expand_as(x)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)     # average pooling along the channel dim
        max_out, _ = torch.max(x, dim=1, keepdim=True)   # max pooling along the channel dim
        out = torch.cat([avg_out, max_out], dim=1)       # 2-channel input
        out = self.conv(out)
        out = self.sigmoid(out)
        return x * out


class CBAM(nn.Module):
    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_planes, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x)
        out = self.spatial_attention(out)
        return out
  • Recommended insertion points

    • After convolution layers, as a feature-enhancement module;

    • Inside each convolutional block of the encoder or decoder (e.g., after every Down or Up block of a UNet);

    • At the bridge stage (between encoder and decoder), to strengthen high-level semantic representation.

  • Why insert it

    • It chains channel attention and spatial attention, reinforcing information along both the "channel" and "spatial location" dimensions;

    • It suppresses redundant background regions and highlights the key channels and spatial locations of the hemorrhage region;

    • The module is lightweight, clearly effective, and easy to embed in any CNN architecture (see the placement sketch below).
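A minimal placement sketch, assuming the CBAM module defined above; DownBlockWithCBAM is a hypothetical UNet-style encoder block, shown only to illustrate where CBAM sits:

import torch
import torch.nn as nn


class DownBlockWithCBAM(nn.Module):
    # Hypothetical UNet encoder block: two convs, then CBAM, then downsampling.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.cbam = CBAM(out_ch)   # channel + spatial attention after the conv pair
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv(x)
        x = self.cbam(x)           # refine features before downsampling
        return self.pool(x), x     # pooled output for the next stage, plus the skip feature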

Non-Local Attention (global spatial self-attention module)

import torch
import torch.nn as nn


class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None):
        super(NonLocalBlock, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels if inter_channels else in_channels // 2
        if self.inter_channels == 0:
            self.inter_channels = 1
        self.g = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.theta = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.W = nn.Conv2d(self.inter_channels, in_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        batch_size, C, H, W = x.size()
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)          # [B, C', H*W]
        g_x = g_x.permute(0, 2, 1)                                         # [B, H*W, C']
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # [B, C', H*W]
        theta_x = theta_x.permute(0, 2, 1)                                 # [B, H*W, C']
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)      # [B, C', H*W]
        f = torch.matmul(theta_x, phi_x)                                   # [B, H*W, H*W]
        f_div_C = nn.functional.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)                                     # [B, H*W, C']
        y = y.permute(0, 2, 1).contiguous()                                # [B, C', H*W]
        y = y.view(batch_size, self.inter_channels, H, W)
        W_y = self.W(y)
        W_y = self.bn(W_y)
        z = W_y + x                                                        # residual connection
        return z

Recommended insertion points:

  • Intermediate layers where the feature map is of moderate size, e.g., mid-to-late encoder stages;

  • Anywhere long-range dependencies need to be captured.

What it does:

  • Models the relation between any two spatial positions and computes attention as a global weighting;

  • Captures long-range contextual dependencies and strengthens feature representation;

  • For brain hemorrhage image segmentation, it helps capture information related to large-extent lesions (a usage example follows below).
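A minimal usage example, assuming the NonLocalBlock defined above; the 256-channel, 32x32 feature map is an arbitrary illustration of a "moderate size" at which the (H*W) x (H*W) attention matrix is still affordable:

import torch

block = NonLocalBlock(in_channels=256)
feat = torch.randn(2, 256, 32, 32)   # mid-level encoder feature map; attention matrix is (32*32) x (32*32)
out = block(feat)                    # residual output, same shape: [2, 256, 32, 32]
print(out.shape)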

Multi-Head Self-Attention module commonly used in Transformers (simplified version)

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x shape: [batch_size, seq_len, embed_dim]
        batch_size, seq_len, embed_dim = x.size()
        qkv = self.qkv_proj(x)                           # [B, S, 3*E]
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)                 # [3, B, heads, seq_len, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]                 # each [B, heads, seq_len, head_dim]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # scaled dot product
        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)      # [B, heads, seq_len, head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        output = self.out_proj(attn_output)
        return output
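This module expects a token sequence of shape [B, seq_len, embed_dim], so to apply it to a CNN feature map the spatial grid must first be flattened into tokens. A minimal usage sketch under that assumption (the sizes are illustrative):

import torch

mhsa = MultiHeadSelfAttention(embed_dim=256, num_heads=8)
feat = torch.randn(2, 256, 16, 16)                         # [B, C, H, W] feature map
tokens = feat.flatten(2).transpose(1, 2)                   # [B, H*W, C] token sequence
tokens = mhsa(tokens)                                      # self-attention over all spatial positions
feat_out = tokens.transpose(1, 2).reshape(2, 256, 16, 16)  # back to [B, C, H, W]
print(feat_out.shape)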

BAM Module (Bottleneck Attention Module, channel + spatial attention)

import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(gate_channels, gate_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(gate_channels // reduction_ratio, gate_channels, 1, bias=False)
        )

    def forward(self, x):
        y = self.mlp(x)
        return y


class SpatialGate(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        self.spatial = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(1)
        )

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = avg_out + max_out
        y = self.spatial(y)
        return y


class BAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, kernel_size=7):
        super(BAM, self).__init__()
        self.channel_gate = ChannelGate(gate_channels, reduction_ratio)
        self.spatial_gate = SpatialGate(kernel_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        chn_att = self.channel_gate(x)        # [B, C, 1, 1]
        sp_att = self.spatial_gate(x)         # [B, 1, H, W]
        att = self.sigmoid(chn_att + sp_att)  # broadcast to [B, C, H, W]
        return x * att

Recommended insertion points:

  • After residual blocks in the middle layers of the backbone, e.g., after ResNet residual blocks;

  • After feature fusion between the UNet encoder and decoder;

  • At skip connections, to jointly reweight channel and spatial information.

Rationale:

  • BAM combines spatial and channel attention, helping the model highlight important spatial regions and key channel features;

  • Applied after the deeper layers, where features are richer, it reinforces important information and suppresses the irrelevant;

  • Hemorrhage lesions in brain CT are locally salient, so BAM can help localize the lesion region (see the placement sketch below).
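A minimal placement sketch, assuming the BAM module defined above; StageWithBAM and the small nn.Sequential stage are hypothetical stand-ins for a real backbone stage (e.g., a ResNet layer):

import torch
import torch.nn as nn


class StageWithBAM(nn.Module):
    # Hypothetical wrapper: run an existing backbone stage, then reweight its output with BAM.
    def __init__(self, stage, channels):
        super().__init__()
        self.stage = stage
        self.bam = BAM(channels)

    def forward(self, x):
        x = self.stage(x)
        return self.bam(x)   # joint channel + spatial gating after the stage


# illustrative stage: a strided conv standing in for a residual stage
stage = nn.Sequential(
    nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
)
block = StageWithBAM(stage, channels=128)
out = block(torch.randn(2, 64, 64, 64))   # -> [2, 128, 32, 32]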

ECA Module (Efficient Channel Attention, channel attention)

import torch
import torch.nn as nn
import torch.nn.functional as F


class ECALayer(nn.Module):
    def __init__(self, channel, k_size=3):
        super(ECALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [B, C, H, W]
        y = self.avg_pool(x)                                 # [B, C, 1, 1]
        y = y.squeeze(-1).transpose(-1, -2)                  # [B, 1, C]
        y = self.conv(y)                                     # [B, 1, C]
        y = self.sigmoid(y).transpose(-1, -2).unsqueeze(-1)  # [B, C, 1, 1]
        return x * y.expand_as(x)

Recommended insertion points:

  • After the output of a convolutional block, e.g., at the end of a group of convolution layers;

  • After the convolutional output of each UNet encoder stage, for lightweight channel reweighting;

  • In lightweight networks, as a replacement for the heavier SE module.

Rationale:

  • ECA models channel relationships with very few extra parameters, quickly improving channel feature quality;

  • In brain hemorrhage CT, different channels may be sensitive to different lesion structures, and ECA adjusts the channel weights dynamically;

  • It works best when inserted before the feature maps have been heavily downsampled (see the placement sketch below).
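A minimal placement sketch, assuming the ECALayer defined above; ConvBlockWithECA is a hypothetical convolutional block that shows ECA as a drop-in, lighter alternative to SE at the end of the block:

import torch
import torch.nn as nn


class ConvBlockWithECA(nn.Module):
    # Hypothetical conv block: conv-BN-ReLU followed by ECA channel reweighting.
    def __init__(self, in_ch, out_ch, k_size=3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.eca = ECALayer(out_ch, k_size)

    def forward(self, x):
        x = self.conv(x)
        return self.eca(x)   # lightweight channel attention, no fully connected layers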

Coordinate Attention (CA)

import torch
import torch.nn as nn
import torch.nn.functional as F


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # pool along width, output width = 1
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # pool along height, output height = 1
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.ReLU()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        # directional pooling along the two spatial axes
        x_h = self.pool_h(x)                      # [N, C, H, 1]
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # [N, C, 1, W] -> [N, C, W, 1]
        y = torch.cat([x_h, x_w], dim=2)          # [N, C, H+W, 1]
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()          # attention along the height axis
        a_w = self.conv_w(x_w).sigmoid()          # attention along the width axis
        out = identity * a_h * a_w
        return out

Recommended insertion points:

  • After feature extraction in the encoder, especially while the spatial resolution is still large;

  • After multi-scale fusion in the decoder, to strengthen spatial localization;

  • After skip connections, to reinforce the spatial position information of the features.

Rationale:

  • CA captures not only channel relationships but also explicit coordinate information along the height and width axes, which suits tasks requiring precise lesion localization;

  • For brain hemorrhage CT segmentation, accurately capturing spatial position is critical;

  • It improves the model's sensitivity to edges and fine details (see the placement sketch below).
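A minimal placement sketch, assuming the CoordAtt module defined above; UpBlockWithCA is a hypothetical UNet-style decoder block, shown only to illustrate applying CA after the skip connection has been fused:

import torch
import torch.nn as nn


class UpBlockWithCA(nn.Module):
    # Hypothetical decoder block: upsample, concatenate the skip feature, fuse, then Coordinate Attention.
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.ca = CoordAtt(out_ch, out_ch)   # position-aware reweighting after fusion

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.fuse(x)
        return self.ca(x)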

SKNet Module (Selective Kernel)

import torch
import torch.nn as nn
import torch.nn.functional as F


class SKConv(nn.Module):
    def __init__(self, features, M=2, G=32, r=16, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M                  # number of branches
        self.features = features
        self.convs = nn.ModuleList()
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=1,
                          padding=1 + i, groups=G, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
            ))
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList()
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.size(0)
        feats = []
        for conv in self.convs:
            feats.append(conv(x))
        feats = torch.stack(feats, dim=1)                         # [B, M, C, H, W]
        U = torch.sum(feats, dim=1)                               # fuse branches [B, C, H, W]
        s = U.mean(-1).mean(-1)                                   # global average pooling [B, C]
        z = self.fc(s)                                            # [B, d]
        attention_vectors = []
        for fc in self.fcs:
            attention_vectors.append(fc(z).unsqueeze(1))          # [B, 1, C]
        attention_vectors = torch.cat(attention_vectors, dim=1)   # [B, M, C]
        attention_vectors = self.softmax(attention_vectors)       # normalize weights across branches
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # [B, M, C, 1, 1]
        out = (feats * attention_vectors).sum(dim=1)              # weighted sum of the branches
        return out

Recommended insertion points:

  • Between convolution layers, replacing a standard convolution as a multi-scale feature-extraction module;

  • In the middle stages of the encoder, to dynamically select the receptive field via multi-scale convolutions;

  • To improve detection of hemorrhage regions of different sizes.

Rationale:

  • SKNet dynamically fuses features from different kernel sizes, adapting to target regions of varying scale;

  • It responds effectively to hemorrhage blobs of different sizes in brain CT images;

  • Inserted in the convolutional stages, it captures multi-scale context more effectively (see the usage example below).
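A minimal usage example, assuming the SKConv defined above; the sizes are illustrative, and note that features must be divisible by the group count G:

import torch

sk = SKConv(features=128, M=2, G=32, r=16, L=32)   # two branches: 3x3 and 5x5 kernels
x = torch.randn(2, 128, 64, 64)
y = sk(x)          # branch outputs fused by learned per-channel weights
print(y.shape)     # torch.Size([2, 128, 64, 64]); can replace a standard 3x3 conv of equal width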
