计算机视觉注意力机制【一】常用注意力机制整理
在做目标检测项目,尤其是基于 YOLOv5 或 YOLOv7 的改进实验时,我发现不同注意力机制对模型性能的提升确实有明显影响,比如提高小目标检测能力、增强特征表达等。但每次找代码都得翻论文、找 GitHub,效率很低。所以我干脆把常见的注意力模块(比如 SE、CBAM、ShuffleAttention、SimAM 等)都整理到一起,统一了格式和接口,方便自己后续做结构替换和对比实验。这个整理也能帮助我更系统地理解各类注意力机制的原理和实现方式,也希望能为有类似需求的人提供一些参考。
后续会基于注意机制与 YOLO 目标检测进行融合,欢迎各位关注➕收藏
文章目录
- 🔹 1. SEAttention(Squeeze-and-Excitation Attention)
- 🔹 2. ShuffleAttention
- 🔹 3. CrissCrossAttention(CCA)
- 🔹 4. S2-MLPv2 Attention
- 🔹 5. SimAM
- 🔹 6. SKAttention(Selective Kernel)
- 🔹 7. NAMAttention(Normalization-based Attention)
- 🔹 8. SOCA(Second-order Channel Attention)
- 🔹 9. CBAM(Convolutional Block Attention Module)
- 🔹 10. GAMAttention
- 🔹 11. Coordinate attention
- 🔹 12. Efficient Channel Attention(ECA)
🔹 1. SEAttention(Squeeze-and-Excitation Attention)
来源:
https://arxiv.org/abs/1709.01507
机制:
全局平均池化 → 两层 MLP → Sigmoid → 通道权重调整
class SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)
🔹 2. ShuffleAttention
来源:
https://arxiv.org/pdf/2102.00240.pdf
机制:
通道注意力 + 空间注意力,利用 GroupNorm 和 Shuffle 操作
class ShuffleAttention(nn.Module):def __init__(self, channel=512,reduction=16,G=8):super().__init__()self.G=Gself.channel=channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid=nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()#group into subfeaturesx=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w#channel_splitx_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w#channel attentionx_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1x_channel=x_0*self.sigmoid(x_channel)#spatial attentionx_spatial=self.gn(x_1) #bs*G,c//(2*G),h,wx_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,wx_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w# concatenate along channel axisout=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,wout=out.contiguous().view(b,-1,h,w)# channel shuffleout = self.channel_shuffle(out, 2)return out
🔹 3. CrissCrossAttention(CCA)
来源: CCNet-Pure-Pytorch
机制: 分别在 H 和 W 方向做注意力交叉计算,融合上下文
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmaxdef INF(B,H,W):return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)class CrissCrossAttention(nn.Module):""" Criss-Cross Attention Module"""def __init__(self, in_dim):super(CrissCrossAttention,self).__init__()self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)self.softmax = Softmax(dim=3)self.INF = INFself.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):m_batchsize, _, height, width = x.size()proj_query = self.query_conv(x)proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)proj_key = self.key_conv(x)proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)proj_value = self.value_conv(x)proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)concate = self.softmax(torch.cat([energy_H, energy_W], 3))att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)#print(concate)#print(att_H) att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)#print(out_H.size(),out_W.size())return self.gamma*(out_H + out_W) + x
🔹 4. S2-MLPv2 Attention
来源:
https://arxiv.org/abs/2108.01072
机制:
Spatial Shift + MLP + 分支融合
def spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]return xclass SplitAttention(nn.Module):def __init__(self,channel=512,k=3):super().__init__()self.channel=channelself.k=kself.mlp1=nn.Linear(channel,channel,bias=False)self.gelu=nn.GELU()self.mlp2=nn.Linear(channel,channel*k,bias=False)self.softmax=nn.Softmax(1)def forward(self,x_all):b,k,h,w,c=x_all.shapex_all=x_all.reshape(b,k,-1,c) a=torch.sum(torch.sum(x_all,1),1) hat_a=self.mlp2(self.gelu(self.mlp1(a))) hat_a=hat_a.reshape(b,self.k,c) bar_a=self.softmax(hat_a) attention=bar_a.unsqueeze(-2) out=attention*x_all out=torch.sum(out,1).reshape(b,h,w,c)return outclass S2Attention(nn.Module):def __init__(self, channels=512 ):super().__init__()self.mlp1 = nn.Linear(channels,channels*3)self.mlp2 = nn.Linear(channels,channels)self.split_attention = SplitAttention()def forward(self, x):b,c,w,h = x.size()x=x.permute(0,2,3,1)x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c])x2 = spatial_shift2(x[:,:,:,c:c*2])x3 = x[:,:,:,c*2:]x_all=torch.stack([x1,x2,x3],1)a = self.split_attention(x_all)x = self.mlp2(a)x=x.permute(0,3,1,2)return x
🔹 5. SimAM
机制: 使用方差引导通道激活,无需参数
import torch
import torch.nn as nnclass SimAM(torch.nn.Module):def __init__(self, channels = None,out_channels = None, e_lambda = 1e-4):super(SimAM, self).__init__()self.activaton = nn.Sigmoid()self.e_lambda = e_lambdadef __repr__(self):s = self.__class__.__name__ + '('s += ('lambda=%f)' % self.e_lambda)return s@staticmethoddef get_module_name():return "simam"def forward(self, x):b, c, h, w = x.size()n = w * h - 1x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5 #atbriassreturn x * self.activaton(y)
🔹 6. SKAttention(Selective Kernel)
机制: 多尺度卷积 + Soft attention 选择合适卷积核
class SKAttention(nn.Module):def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):super().__init__()self.d=max(L,channel//reduction)self.convs=nn.ModuleList([])for k in kernels:self.convs.append(nn.Sequential(OrderedDict([('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),('bn',nn.BatchNorm2d(channel)),('relu',nn.ReLU())])))self.fc=nn.Linear(channel,self.d)self.fcs=nn.ModuleList([])for i in range(len(kernels)):self.fcs.append(nn.Linear(self.d,channel))self.softmax=nn.Softmax(dim=0)def forward(self, x):bs, c, _, _ = x.size()conv_outs=[]### split atbriassfor conv in self.convs:conv_outs.append(conv(x))feats=torch.stack(conv_outs,0)#k,bs,channel,h,w### fuseU=sum(conv_outs) #bs,c,h,w### reduction channelS=U.mean(-1).mean(-1) #bs,cZ=self.fc(S) #bs,d### calculate attention weightweights=[]for fc in self.fcs:weight=fc(Z)weights.append(weight.view(bs,c,1,1)) #bs,channelattention_weughts=torch.stack(weights,0)#k,bs,channel,1,1attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1### fuseV=(attention_weughts*feats).sum(0)return V
🔹 7. NAMAttention(Normalization-based Attention)
机制: 使用 BN 参数的归一化特性引导注意力
import torch.nn as nn
import torch
from torch.nn import functional as Fclass Channel_Att(nn.Module):def __init__(self, channels, t=16):super(Channel_Att, self).__init__()self.channels = channelsself.bn2 = nn.BatchNorm2d(self.channels, affine=True)def forward(self, x):residual = xx = self.bn2(x)weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())x = x.permute(0, 2, 3, 1).contiguous()x = torch.mul(weight_bn, x)x = x.permute(0, 3, 1, 2).contiguous()x = torch.sigmoid(x) * residual #return xclass NAMAttention(nn.Module):def __init__(self, channels, out_channels=None, no_spatial=True):super(NAMAttention, self).__init__()self.Channel_Att = Channel_Att(channels)def forward(self, x):x_out1=self.Channel_Att(x)return x_out1
🔹 8. SOCA(Second-order Channel Attention)
机制: 基于协方差池化和矩阵平方根的高阶通道注意力机制
import numpy as np
import torch
from torch import nn
from torch.nn import initfrom torch.autograd import Functionclass Covpool(Function):@staticmethoddef forward(ctx, input):x = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]h = x.data.shape[2]w = x.data.shape[3]M = h*wx = x.reshape(batchSize,dim,M)I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device)I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype)y = x.bmm(I_hat).bmm(x.transpose(1,2))ctx.save_for_backward(input,I_hat)return y@staticmethoddef backward(ctx, grad_output):input,I_hat = ctx.saved_tensorsx = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]h = x.data.shape[2]w = x.data.shape[3]M = h*wx = x.reshape(batchSize,dim,M)grad_input = grad_output + grad_output.transpose(1,2)grad_input = grad_input.bmm(x).bmm(I_hat)grad_input = grad_input.reshape(batchSize,dim,h,w)return grad_inputclass Sqrtm(Function):@staticmethoddef forward(ctx, input, iterN):x = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]dtype = x.dtypeI3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1)A = x.div(normA.view(batchSize,1,1).expand_as(x))Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device)Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1)if iterN < 2:ZY = 0.5*(I3 - A)Y[:,0,:,:] = A.bmm(ZY)else:ZY = 0.5*(I3 - A)Y[:,0,:,:] = A.bmm(ZY)Z[:,0,:,:] = ZYfor i in range(1, iterN-1):ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:]))Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY)Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:])ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)ctx.save_for_backward(input, A, ZY, normA, Y, Z)ctx.iterN = iterNreturn y@staticmethoddef backward(ctx, grad_output):input, A, ZY, normA, Y, Z = ctx.saved_tensorsiterN = ctx.iterNx = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]dtype = x.dtypeder_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA))I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)if iterN < 2:der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace))else:dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) -Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom))dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:])for i in range(iterN-3, -1, -1):YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:])ZY = Z[:,i,:,:].bmm(Y[:,i,:,:])dldY_ = 0.5*(dldY.bmm(YZ) - Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - ZY.bmm(dldY))dldZ_ = 0.5*(YZ.bmm(dldZ) - Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) -dldZ.bmm(ZY))dldY = dldY_dldZ = dldZ_der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x))grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)for i in range(batchSize):grad_input[i,:,:] += (der_postComAux[i] \- grad_aux[i] / (normA[i] * normA[i])) \*torch.ones(dim,device = x.device).diag()return grad_input, Nonedef CovpoolLayer(var):return Covpool.apply(var)def SqrtmLayer(var, iterN):return Sqrtm.apply(var, iterN)class SOCA(nn.Module):# second-order Channel attentiondef __init__(self, channel, reduction=8):super(SOCA, self).__init__()self.max_pool = nn.MaxPool2d(kernel_size=2)self.conv_du = nn.Sequential(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),nn.ReLU(inplace=True),nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),nn.Sigmoid())def forward(self, x):batch_size, C, h, w = x.shape # x: NxCxHxWN = int(h * w)min_h = min(h, w)h1 = 1000w1 = 1000if h < h1 and w < w1:x_sub = xelif h < h1 and w > w1:W = (w - w1) // 2x_sub = x[:, :, :, W:(W + w1)]elif w < w1 and h > h1:H = (h - h1) // 2x_sub = x[:, :, H:H + h1, :]else:H = (h - h1) // 2W = (w - w1) // 2x_sub = x[:, :, H:(H + h1), W:(W + w1)]cov_mat = CovpoolLayer(x_sub) # Global Covariance pooling layercov_mat_sqrt = SqrtmLayer(cov_mat,5) # Matrix square root layer( including pre-norm,Newton-Schulz iter. and post-com. with 5 iteration)cov_mat_sum = torch.mean(cov_mat_sqrt,1)cov_mat_sum = cov_mat_sum.view(batch_size,C,1,1)y_cov = self.conv_du(cov_mat_sum)return y_cov*x
🔹 9. CBAM(Convolutional Block Attention Module)
机制: Channel Attention + Spatial Attention 串联使用
class ChannelAttentionModule(nn.Module):def __init__(self, c1, reduction=16):super(ChannelAttentionModule, self).__init__()mid_channel = c1 // reductionself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Linear(in_features=c1, out_features=mid_channel),nn.LeakyReLU(0.1, inplace=True),nn.Linear(in_features=mid_channel, out_features=c1))self.act = nn.Sigmoid()#self.act=nn.SiLU()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)maxout = self.shared_MLP(self.max_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)return self.act(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.act = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.act(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, c1,c2):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(c1)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn out
🔹 10. GAMAttention
原理图:
import numpy as np
import torch
from torch import nn
from torch.nn import initclass GAMAttention(nn.Module):#https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True,rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), nn.BatchNorm2d(int(c1 /rate)),nn.ReLU(inplace=True),nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle out = x * x_spatial_attreturn out def channel_shuffle(x, groups=2):B, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out=out.view(B, C, H, W) return out
🔹 11. Coordinate attention
class h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)
class CA(nn.Module):# Coordinate Attention for Efficient Mobile Network Design'''Recent studies on mobile network design have demonstrated the remarkable effectiveness of channel attention (e.g., the Squeeze-and-Excitation attention) for liftingmodel performance, but they generally neglect the positional information, which is important for generating spatially selective attention maps. In this paper, we propose anovel attention mechanism for mobile iscyy networks by embedding positional information into channel attention, whichwe call “coordinate attention”. Unlike channel attentionthat transforms a feature tensor to a single feature vector iscyy via 2D global pooling, the coordinate attention factorizes channel attention into two 1D feature encoding processes that aggregate features along the two spatial directions, respectively'''def __init__(self, inp, oup, reduction=32):super(CA, self).__init__()mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn,c,h,w = x.size()pool_h = nn.AdaptiveAvgPool2d((h, 1))pool_w = nn.AdaptiveAvgPool2d((1, w))x_h = pool_h(x)x_w = pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.act(y) x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn out
🔹 12. Efficient Channel Attention(ECA)
import torch.nn as nn
import torch
from torch.nn import functional as Fclass ECAttention(nn.Module):"""Constructs a ECA module.Args:channel: Number of channels of the input feature mapk_size: Adaptive selection of kernel size automg"""def __init__(self, c1,c2, k_size=3):super(ECAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)