当前位置: 首页 > news >正文

09 - TripletAttention模块

论文《Rotate to Attend: Convolutional Triplet Attention Module》

1、作用

Triplet Attention是一种新颖的注意力机制,它通过捕获跨维度交互,利用三分支结构来计算注意力权重。对于输入张量,Triplet Attention通过旋转操作建立维度间的依赖关系,随后通过残差变换对信道和空间信息进行编码,实现了几乎不增加计算成本的情况下,有效增强视觉表征的能力。

2、机制

1、三分支结构

Triplet Attention包含三个分支,每个分支负责捕获输入的空间维度H或W与信道维度C之间的交互特征。

2、跨维度交互

通过在每个分支中对输入张量进行排列(permute)操作,并通过Z-pool和k×k的卷积层处理,以捕获跨维度的交互特征。

3、注意力权重的生成

利用sigmoid激活层生成注意力权重,并应用于排列后的输入张量,然后将其排列回原始输入形状。

3、 独特优势

1、跨维度交互

Triplet Attention通过捕获输入张量的跨维度交互,提供了丰富的判别特征表征,较之前的注意力机制(如SENet、CBAM等)能够更有效地增强网络的性能。

2、几乎无计算成本增加

相比于传统的注意力机制,Triplet Attention在提升网络性能的同时,几乎不增加额外的计算成本和参数数量,使得它可以轻松地集成到经典的骨干网络中。

3、无需降维

与其他注意力机制不同,Triplet Attention不进行维度降低处理,这避免了因降维可能导致的信息丢失,保证了信道与权重间的直接对应关系。

总的来说,Triplet Attention通过其独特的三分支结构和跨维度交互机制,在提高模型性能的同时,保持了计算效率,显示了其在各种视觉任务中的应用潜力。

4、代码

import torch
import torch.nn as nn# 定义一个基本的卷积模块,包括卷积、批归一化和ReLU激活
class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planes# 定义卷积层self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)# 条件性地添加批归一化层self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None# 条件性地添加ReLU激活函数self.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)  # 应用卷积if self.bn is not None:x = self.bn(x)  # 应用批归一化if self.relu is not None:x = self.relu(x)  # 应用ReLUreturn x# 定义ZPool模块,结合最大池化和平均池化结果
class ZPool(nn.Module):def forward(self, x):# 结合最大值和平均值return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)# 定义注意力门,用于根据输入特征生成注意力权重
class AttentionGate(nn.Module):def __init__(self):super(AttentionGate, self).__init__()kernel_size = 7  # 设定卷积核大小self.compress = ZPool()  # 使用ZPool模块self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)  # 通过卷积调整通道数def forward(self, x):x_compress = self.compress(x)  # 应用ZPoolx_out = self.conv(x_compress)  # 通过卷积生成注意力权重scale = torch.sigmoid_(x_out)  # 应用Sigmoid激活return x * scale  # 将注意力权重乘以原始特征# 定义TripletAttention模块,结合了三种不同方向的注意力门
class TripletAttention(nn.Module):def __init__(self, no_spatial=False):super(TripletAttention, self).__init__()self.cw = AttentionGate()  # 定义宽度方向的注意力门self.hc = AttentionGate()  # 定义高度方向的注意力门self.no_spatial = no_spatial  # 是否忽略空间注意力if not no_spatial:self.hw = AttentionGate()  # 定义空间方向的注意力门def forward(self, x):# 应用注意力门并结合结果x_perm1 = x.permute(0, 2, 1, 3).contiguous()  # 转置以应用宽度方向的注意力x_out1 = self.cw(x_perm1)x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()  # 还原转置x_perm2 = x.permute(0, 3, 2, 1).contiguous()  # 转置以应用高度方向的注意力x_out2 = self.hc(x_perm2)x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()  # 还原转置if not self.no_spatial:x_out = self.hw(x)  # 应用空间注意力x_out = 1 / 3 * (x_out + x_out11 + x_out21)  # 结合三个方向的结果else:x_out = 1 / 2 * (x_out11 + x_out21)  # 结合两个方向的结果(如果no_spatial为True)return x_out# 示例代码
if __name__ == '__main__':input = torch.randn(50, 512, 7, 7)  # 生成随机输入triplet = TripletAttention()  # 实例化TripletAttentionoutput = triplet(input)  # 应用TripletAttentionprint(output.shape)  # 打印输出形状

相关文章:

  • RAG数据集综述
  • 第六章 进阶19 琦琦的追求
  • Windows 文件复制利器:ROBOCOPY 拷贝命令指南
  • 全球域名WHOIS信息查询免费API接口教程
  • Tlias-web 管理系统项目知识点复盘总结
  • 高性能Tick级别高频交易引擎设计与实现
  • 6月13日day52打卡
  • 西电新增信息力学与感知学院,26考研正式招生
  • 【python深度学习】DAY 52 神经网络调参
  • 第三章支线八 ·构建之巅 · 工具链与打包炼金术
  • PHP商城源码:构建高效电商平台的利器
  • DeepSeek 引领前端开发变革:AI 助力学习与工作新路径
  • record类型-Java 16
  • 使用 PolarProxy+Proxifier 解密 TLS 流量
  • Stone 3D使用RemoteMesh组件极大的缩小工程文件尺寸
  • python实现鸟类识别系统实现方案
  • C++中 using 命名别名和命名别名模板的用法
  • 提升搜索可见度的基石:标题标签设置原则与SEO效能量化分析
  • 服务自动添加实例工具
  • 中国温室气体排放因子数据库
  • 网站流量分析软件/谷歌网页版
  • 百度域名的ip地址/seo 优化教程
  • 杭州全网推广/网站整站优化
  • ie浏览器打不开建设银行网站/职业技能培训有哪些
  • 深圳vi设计手册/seo关键词优化案例
  • 有那些做自媒体短视频的网站/营销策划书格式及范文