09 - TripletAttention模块
论文《Rotate to Attend: Convolutional Triplet Attention Module》
1、作用
Triplet Attention是一种新颖的注意力机制,它通过捕获跨维度交互,利用三分支结构来计算注意力权重。对于输入张量,Triplet Attention通过旋转操作建立维度间的依赖关系,随后通过残差变换对信道和空间信息进行编码,实现了几乎不增加计算成本的情况下,有效增强视觉表征的能力。
2、机制
1、三分支结构:
Triplet Attention包含三个分支,每个分支负责捕获输入的空间维度H或W与信道维度C之间的交互特征。
2、跨维度交互:
通过在每个分支中对输入张量进行排列(permute)操作,并通过Z-pool和k×k的卷积层处理,以捕获跨维度的交互特征。
3、注意力权重的生成:
利用sigmoid激活层生成注意力权重,并应用于排列后的输入张量,然后将其排列回原始输入形状。
3、 独特优势
1、跨维度交互:
Triplet Attention通过捕获输入张量的跨维度交互,提供了丰富的判别特征表征,较之前的注意力机制(如SENet、CBAM等)能够更有效地增强网络的性能。
2、几乎无计算成本增加:
相比于传统的注意力机制,Triplet Attention在提升网络性能的同时,几乎不增加额外的计算成本和参数数量,使得它可以轻松地集成到经典的骨干网络中。
3、无需降维:
与其他注意力机制不同,Triplet Attention不进行维度降低处理,这避免了因降维可能导致的信息丢失,保证了信道与权重间的直接对应关系。
总的来说,Triplet Attention通过其独特的三分支结构和跨维度交互机制,在提高模型性能的同时,保持了计算效率,显示了其在各种视觉任务中的应用潜力。
4、代码
import torch
import torch.nn as nn# 定义一个基本的卷积模块,包括卷积、批归一化和ReLU激活
class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planes# 定义卷积层self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)# 条件性地添加批归一化层self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None# 条件性地添加ReLU激活函数self.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x) # 应用卷积if self.bn is not None:x = self.bn(x) # 应用批归一化if self.relu is not None:x = self.relu(x) # 应用ReLUreturn x# 定义ZPool模块,结合最大池化和平均池化结果
class ZPool(nn.Module):def forward(self, x):# 结合最大值和平均值return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)# 定义注意力门,用于根据输入特征生成注意力权重
class AttentionGate(nn.Module):def __init__(self):super(AttentionGate, self).__init__()kernel_size = 7 # 设定卷积核大小self.compress = ZPool() # 使用ZPool模块self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False) # 通过卷积调整通道数def forward(self, x):x_compress = self.compress(x) # 应用ZPoolx_out = self.conv(x_compress) # 通过卷积生成注意力权重scale = torch.sigmoid_(x_out) # 应用Sigmoid激活return x * scale # 将注意力权重乘以原始特征# 定义TripletAttention模块,结合了三种不同方向的注意力门
class TripletAttention(nn.Module):def __init__(self, no_spatial=False):super(TripletAttention, self).__init__()self.cw = AttentionGate() # 定义宽度方向的注意力门self.hc = AttentionGate() # 定义高度方向的注意力门self.no_spatial = no_spatial # 是否忽略空间注意力if not no_spatial:self.hw = AttentionGate() # 定义空间方向的注意力门def forward(self, x):# 应用注意力门并结合结果x_perm1 = x.permute(0, 2, 1, 3).contiguous() # 转置以应用宽度方向的注意力x_out1 = self.cw(x_perm1)x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() # 还原转置x_perm2 = x.permute(0, 3, 2, 1).contiguous() # 转置以应用高度方向的注意力x_out2 = self.hc(x_perm2)x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() # 还原转置if not self.no_spatial:x_out = self.hw(x) # 应用空间注意力x_out = 1 / 3 * (x_out + x_out11 + x_out21) # 结合三个方向的结果else:x_out = 1 / 2 * (x_out11 + x_out21) # 结合两个方向的结果(如果no_spatial为True)return x_out# 示例代码
if __name__ == '__main__':input = torch.randn(50, 512, 7, 7) # 生成随机输入triplet = TripletAttention() # 实例化TripletAttentionoutput = triplet(input) # 应用TripletAttentionprint(output.shape) # 打印输出形状