Attention Mechanism Module Code
- Widely recommended: SE, ECA, Coordinate Attention (CA); lightweight, easy to use, and consistently effective.
- Still usable, but weigh the computational cost: BAM, GCNet, SKNet.
- Generally not a first choice (considered "dated" or being phased out): Non-local, DANet; they are especially hard to apply to large-scale or 3D medical images.
SE Module (Squeeze-and-Excitation, channel attention)
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)      # [B, C]
        y = self.fc(y).view(b, c, 1, 1)      # channel attention weights
        return x * y.expand_as(x)
Suitable insertion points:
- As a channel-attention module after convolutional layers, typically at the end of each convolutional or residual block (see the usage sketch below);
- After the convolutional output of each encoder stage, to recalibrate the channels.
Purpose:
- Generates channel weights through "squeeze" (global average pooling) and "excitation" (two fully connected layers);
- Strengthens the model's response to key channels and suppresses irrelevant ones;
- Simple structure, few parameters, easy to plug in.
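A minimal usage sketch, placing SEBlock at the end of a small convolutional block; the ConvSEBlock wrapper, its channel sizes, and the input shape are illustrative assumptions, not part of the module itself:

import torch
import torch.nn as nn

class ConvSEBlock(nn.Module):
    """Hypothetical conv block with SE recalibration at the end."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.se = SEBlock(out_ch)   # channel recalibration after the conv stack

    def forward(self, x):
        return self.se(self.conv(x))

x = torch.randn(2, 64, 32, 32)
print(ConvSEBlock(64, 128)(x).shape)   # torch.Size([2, 128, 32, 32])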
CBAM Module (Convolutional Block Attention Module, channel + spatial attention)
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_planes, in_planes // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(in_planes // reduction, in_planes, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        out = self.sigmoid(out).view(b, c, 1, 1)
        return x * out.expand_as(x)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)    # average pooling over the channel dim
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # max pooling over the channel dim
        out = torch.cat([avg_out, max_out], dim=1)      # 2-channel input to the spatial conv
        out = self.conv(out)
        out = self.sigmoid(out)
        return x * out


class CBAM(nn.Module):
    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_planes, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x)
        out = self.spatial_attention(out)
        return out
Suitable insertion points:
- After convolutional layers, as a feature-enhancement module (see the usage sketch below);
- Inside each convolutional block of the encoder or decoder (e.g., after every Down or Up block in a UNet);
- At the bridge stage (between encoder and decoder), to strengthen high-level semantic representation.
Why insert it:
- Applies channel attention and spatial attention in sequence, reinforcing information along both the channel and the spatial-position dimensions;
- Suppresses redundant background regions and highlights the key channels and spatial locations of the hemorrhage region;
- Lightweight, clearly effective, and easy to embed in any CNN architecture.
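A minimal sketch of embedding CBAM into a UNet-style down block; the DoubleConv-plus-CBAM structure, channel sizes, and input shape are illustrative assumptions, not taken from the original text:

import torch
import torch.nn as nn

class DownBlockWithCBAM(nn.Module):
    """Hypothetical UNet down block: two convs, CBAM refinement, then downsampling."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
        self.cbam = CBAM(out_ch)
        self.down = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.cbam(self.convs(x))   # refine features before passing them on
        return self.down(x), x         # downsampled feature + skip feature

x = torch.randn(2, 32, 64, 64)
down, skip = DownBlockWithCBAM(32, 64)(x)
print(down.shape, skip.shape)          # [2, 64, 32, 32] and [2, 64, 64, 64]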
Non-Local Attention (global spatial self-attention module)
import torch
import torch.nn as nn


class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None):
        super(NonLocalBlock, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels if inter_channels else in_channels // 2
        if self.inter_channels == 0:
            self.inter_channels = 1
        self.g = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.theta = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.W = nn.Conv2d(self.inter_channels, in_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        batch_size, C, H, W = x.size()
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)          # [B, C', H*W]
        g_x = g_x.permute(0, 2, 1)                                         # [B, H*W, C']
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # [B, C', H*W]
        theta_x = theta_x.permute(0, 2, 1)                                 # [B, H*W, C']
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)      # [B, C', H*W]
        f = torch.matmul(theta_x, phi_x)                                   # [B, H*W, H*W]
        f_div_C = nn.functional.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)                                     # [B, H*W, C']
        y = y.permute(0, 2, 1).contiguous()                                # [B, C', H*W]
        y = y.view(batch_size, self.inter_channels, H, W)
        W_y = self.W(y)
        W_y = self.bn(W_y)
        z = W_y + x                                                        # residual connection
        return z
Suitable insertion points:
- Middle layers where the feature map is moderately sized, e.g., mid-to-late encoder stages (see the usage sketch below);
- Places where long-range dependencies need to be captured.
Purpose:
- Models the relationship between any two spatial positions and computes attention as a global weighting;
- Captures long-range contextual dependencies and strengthens feature representation;
- For brain hemorrhage image segmentation, it helps capture lesion-related information over large areas.
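A minimal usage sketch on a deep encoder feature map; the 256-channel, 16x16 shape is an illustrative assumption chosen because the pairwise affinity matrix grows with (H*W)^2, which is why this block is usually restricted to small feature maps:

import torch

feat = torch.randn(2, 256, 16, 16)   # deep encoder feature: 16*16 = 256 spatial positions
nl = NonLocalBlock(in_channels=256)
out = nl(feat)                       # same shape: [2, 256, 16, 16]
# here the affinity matrix f is [B, 256, 256]; on a 128x128 map it would be
# [B, 16384, 16384], which is why Non-local is costly on large or 3D medical images
print(out.shape)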
Multi-Head Self-Attention Module as Used in Transformers (simplified version)
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x shape: [batch_size, seq_len, embed_dim]
        batch_size, seq_len, embed_dim = x.size()
        qkv = self.qkv_proj(x)                                    # [B, S, 3*E]
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)                          # [3, B, heads, seq_len, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]                          # each [B, heads, seq_len, head_dim]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # scaled dot product
        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)               # [B, heads, seq_len, head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        output = self.out_proj(attn_output)
        return output
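A minimal sketch of applying this module to a CNN feature map by flattening spatial positions into a token sequence; the flattening/reshaping steps and the 128-channel, 8-head sizes are illustrative assumptions, not part of the module above:

import torch

x = torch.randn(2, 128, 16, 16)                 # [B, C, H, W] feature map
b, c, h, w = x.shape
tokens = x.flatten(2).transpose(1, 2)           # [B, H*W, C] sequence of spatial tokens
mhsa = MultiHeadSelfAttention(embed_dim=128, num_heads=8)
out = mhsa(tokens)                              # [B, H*W, C]
out = out.transpose(1, 2).reshape(b, c, h, w)   # back to [B, C, H, W]
print(out.shape)                                # torch.Size([2, 128, 16, 16])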
BAM Module (Bottleneck Attention Module, channel + spatial)
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(gate_channels, gate_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(gate_channels // reduction_ratio, gate_channels, 1, bias=False)
        )

    def forward(self, x):
        y = self.mlp(x)          # [B, C, 1, 1] channel attention logits
        return y


class SpatialGate(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        self.spatial = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(1)
        )

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = avg_out + max_out
        y = self.spatial(y)      # [B, 1, H, W] spatial attention logits
        return y


class BAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, kernel_size=7):
        super(BAM, self).__init__()
        self.channel_gate = ChannelGate(gate_channels, reduction_ratio)
        self.spatial_gate = SpatialGate(kernel_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        chn_att = self.channel_gate(x)          # [B, C, 1, 1]
        sp_att = self.spatial_gate(x)           # [B, 1, H, W]
        att = self.sigmoid(chn_att + sp_att)    # broadcast to [B, C, H, W]
        # note: the BAM paper applies the attention residually, roughly x * (1 + att)
        return x * att
Suitable insertion points:
- After residual blocks in the middle of the backbone, e.g., after ResNet residual blocks (see the usage sketch below);
- After feature fusion between the UNet encoder and decoder;
- At skip connections, to jointly weight channel and spatial information.
Why:
- BAM combines spatial and channel attention, helping the model highlight important spatial regions and key channel features;
- Used after deeper layers, where features are richer, it reinforces important information and suppresses irrelevant information more effectively;
- Lesions in brain hemorrhage CT images are locally salient, so BAM can help localize the lesion region.
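A minimal sketch of placing BAM between two backbone stages; the stage1/stage2 fragments, channel counts, and input shape are illustrative placeholders:

import torch
import torch.nn as nn

stage1 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True))
bam = BAM(gate_channels=128)
stage2 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))

x = torch.randn(2, 64, 64, 64)
x = stage1(x)   # [2, 128, 32, 32]
x = bam(x)      # re-weighted, same shape
x = stage2(x)   # [2, 256, 16, 16]
print(x.shape)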
ECA Module (Efficient Channel Attention, channel attention)
import torch
import torch.nn as nn
import torch.nn.functional as F


class ECALayer(nn.Module):
    def __init__(self, channel, k_size=3):
        super(ECALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [B, C, H, W]
        y = self.avg_pool(x)                                  # [B, C, 1, 1]
        y = y.squeeze(-1).transpose(-1, -2)                   # [B, 1, C]
        y = self.conv(y)                                      # [B, 1, C]
        y = self.sigmoid(y).transpose(-1, -2).unsqueeze(-1)   # [B, C, 1, 1]
        return x * y.expand_as(x)
Suitable insertion points:
- After the output of a convolutional block, e.g., at the end of a group of convolutional layers (see the usage sketch below);
- After the convolutional output of each UNet encoder stage, for lightweight channel re-weighting;
- In lightweight networks, as a replacement for the heavier SE module.
Why:
- ECA models channel relationships with very few parameters, quickly improving the quality of channel features;
- In brain hemorrhage CT, different channels may be sensitive to different lesion structures, and ECA adjusts channel weights dynamically;
- It works best when inserted before the feature map has been heavily downsampled.
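A minimal usage sketch that appends ECALayer to the end of a convolutional group; the conv_block definition, the single-channel CT-slice input, and the channel count are illustrative assumptions:

import torch
import torch.nn as nn

conv_block = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False),   # e.g., single-channel CT slice input
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    ECALayer(64, k_size=3),    # lightweight channel re-weighting at the end of the group
)

x = torch.randn(2, 1, 128, 128)
print(conv_block(x).shape)     # torch.Size([2, 64, 128, 128])

k_size=3 is a common fixed choice; the ECA paper also describes choosing the 1D kernel size adaptively from the channel count.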
Coordinate Attention (CA)
import torch
import torch.nn as nn
import torch.nn.functional as F


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))   # average over width, keep height: output [N, C, H, 1]
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))   # average over height, keep width: output [N, C, 1, W]
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.ReLU()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        # directional pooling along height and width
        x_h = self.pool_h(x)                        # [N, C, H, 1]
        x_w = self.pool_w(x).permute(0, 1, 3, 2)    # [N, C, 1, W] -> [N, C, W, 1]
        y = torch.cat([x_h, x_w], dim=2)            # [N, C, H+W, 1]
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_h * a_w
        return out
Suitable insertion points:
- After encoder feature extraction, especially while the spatial dimensions are still large (see the usage sketch below);
- After multi-scale fusion in the decoder, to strengthen spatial localization;
- After skip connections, to reinforce the spatial position information in the features.
Why:
- CA captures not only channel relationships but also explicit height and width coordinate information, making it well suited to tasks that need precise lesion localization;
- For brain hemorrhage CT segmentation, accurately capturing spatial position is critical;
- It improves the model's perception of edges and fine details.
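A minimal sketch of applying CoordAtt to a skip-connection feature before concatenation in a UNet-style decoder; the tensor shapes and the concatenation step are illustrative assumptions:

import torch

skip = torch.randn(2, 64, 64, 64)      # encoder skip feature [B, C, H, W]
up = torch.randn(2, 64, 64, 64)        # upsampled decoder feature, same spatial size

ca = CoordAtt(inp=64, oup=64)          # inp and oup usually match the feature channels
skip = ca(skip)                        # re-weight the skip feature with coordinate attention
fused = torch.cat([up, skip], dim=1)   # [2, 128, 64, 64], fed to the next decoder conv
print(fused.shape)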
SKNet Module (Selective Kernel)
import torch
import torch.nn as nn
import torch.nn.functional as F


class SKConv(nn.Module):
    def __init__(self, features, M=2, G=32, r=16, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M                     # number of branches
        self.features = features
        self.convs = nn.ModuleList()
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=1,
                          padding=1 + i, groups=G, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
            ))
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList()
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.size(0)
        feats = []
        for conv in self.convs:
            feats.append(conv(x))
        feats = torch.stack(feats, dim=1)                 # [B, M, C, H, W]
        U = torch.sum(feats, dim=1)                       # aggregate branches: [B, C, H, W]
        s = U.mean(-1).mean(-1)                           # global average pooling: [B, C]
        z = self.fc(s)                                    # [B, d]
        attention_vectors = []
        for fc in self.fcs:
            attention_vectors.append(fc(z).unsqueeze(1))  # [B, 1, C]
        attention_vectors = torch.cat(attention_vectors, dim=1)            # [B, M, C]
        attention_vectors = self.softmax(attention_vectors)                # normalize across branches
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # [B, M, C, 1, 1]
        out = (feats * attention_vectors).sum(dim=1)      # weighted sum over branches
        return out
Suitable insertion points:
- Between convolutional layers, replacing a standard convolution as a multi-scale feature-extraction module (see the usage sketch below);
- In the middle stages of the encoder, to dynamically select the receptive field through multi-scale convolutions;
- To strengthen detection of hemorrhage regions at different scales.
Why:
- SKNet dynamically fuses features from convolutions with different kernel sizes, adapting to target regions of different sizes;
- It responds effectively to hemorrhage blobs of varying size in brain hemorrhage CT images;
- Inserted into the convolutional stages, it captures multi-scale contextual information more effectively.
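A minimal usage sketch that swaps SKConv in where a standard 3x3 convolution would sit; the surrounding stage, channel counts, and input shape are illustrative assumptions (the channel count is chosen to be divisible by the default G=32 groups):

import torch
import torch.nn as nn

stage = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, bias=False),   # pointwise conv to set the channel count
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    SKConv(features=128, M=2, G=32, r=16),           # two branches: 3x3 and 5x5 kernels
)

x = torch.randn(2, 64, 32, 32)
print(stage(x).shape)                                # torch.Size([2, 128, 32, 32])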