当前位置: 首页 > news >正文

空间注意力机制

知识点:

空间注意力机制 spatial attention SA;

SA 中平均池化和最大池化的操作;

torch.max;


参考博客:通俗易懂理解通道注意力机制(CAM)与空间注意力机制(SAM)-CSDN博客

 


空间注意力机制代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SpatialAttention(nn.Module):def __init__(self,kernel_size=7):"""初始化空间注意力模块Args:kernel_size (int): 卷积核大小,通常为7x7"""super().__init__()# 确保kernel_size是奇数,以便paddingassert kernel_size % 2 ==1padding = kernel_size // 2self.sigmoid = nn.Sigmoid()# 定义7x7卷积层,输入通道为2(平均池化和最大池化的结果),输出通道为1self.conv = nn.Conv2d(in_channels=2,  # 输入通道数为2(平均池化和最大池化的结果)out_channels=1, # 输出通道数为1(生成空间注意力图)kernel_size=kernel_size,  # 卷积核大小,通常为7x7padding=padding,   # 填充,保持特征图大小不变bias=False # 不使用偏置)def forward(self, x):"""前向传播Args:x (torch.Tensor): 输入特征图 [B, C, H, W]Returns:torch.Tensor: 经过空间注意力加权后的特征图"""# 沿着通道维度进行平均池化和最大池化avg_pool = torch.mean(x, dim=1, keepdim=True) # F_avg^s [B,1,H,W]# 注意这里返回值是两个,最大值和索引,要用两个参数接max_pool,_ = torch.max(x, dim=1, keepdim=True)  # F_max^s [B,1,H,W]# 拼接平均池化和最大池化的结果pooled_features = torch.cat((avg_pool, max_pool), dim=1)  # [B,2,H,W]# 通过 7 * 7 卷积层处理spatial_attention = self.conv(pooled_features)# sigmoid激活spatial_attention = self.sigmoid(spatial_attention)return x * spatial_attentionif __name__ == '__main__':# 创建测试数据batch_size=2channels=3height=64width = 64x = torch.randn(batch_size, channels, height, width)sa=SpatialAttention(kernel_size=7)outputs=sa(x)print(f"input shape:{x.shape}")print(f"output shape:{outputs.shape}")

沿通道维度的平均池化

avg_pool = torch.mean(x, dim=1, keepdim=True) # F_avg^s [B,1,H,W]

沿通道维度的最大池化

 max_values, _ = torch.max(x, dim=1, keepdim=True)  # F_max^s [B,1,H,W]

注意这里返回是两个值,最大值索引也返回了,必须要用两个参数接!!!

vs 通道注意力机制中的池化操作

 

http://www.dtcms.com/a/241882.html

相关文章:

  • 如何判断一个bug,是前端还是后端的?
  • 积累-Vue.js 开发实用指南:ElementUI 与核心技巧
  • 【动作】动作标签分类的三大模块
  • 如何在看板中体现优先级变化
  • 【沉浸式求职学习day53】【Spring】
  • 3.3.1_1 检错编码(奇偶校验码)
  • 智能呼入系统助力酒店客服服务
  • 立足数字人文,深化历史叙事|科学智能赋能人文社科领域研究
  • 扁平表+递归拼树思想
  • PyTorch终极实战:从自定义层到模型部署全流程拆解​
  • JS深入之从原型到原型链
  • 替代爬虫!亚马逊API采集商品详情实时数据开发教程
  • 苹果签名应用掉签频繁原因排查,以及如何避免
  • 第十六章 I2C
  • python 中线程、进程、协程
  • 【动作】AVA:时空定位原子视觉动作视频数据集
  • java 数据结构-HashMap
  • 零基础玩转物联网-串口转以太网模块如何快速实现与MQTT服务器通信
  • 如何提升企微CRM系统数据的准确性?5大核心策略详解
  • opencv RGB图像转灰度图
  • 华为云Flexus+DeepSeek征文 | 华为云ModelArts Studio快速上手:DeepSeek-R1-0528商用服务的开通与使用
  • 软件定义汽车的转型之路已然开启
  • SkyWalking 10.2.0 SWCK 配置过程
  • ARM内存理解(一)
  • AutoCAD 2024 保姆级安装教程【2025最新】(附安装包)
  • 智能卷料系统仿真|从动建模到零停机优化—MapleSim卷料处理库工业级解决方案
  • 使用 VSCode 开发 FastAPI 项目(1)
  • 华为云Flexus+DeepSeek征文|体验华为云ModelArts快速搭建Dify-LLM应用开发平台并创建联网大模型
  • VScode - 我的常用插件01 - 主题插件Noctis
  • 【Css】css修改滚动条的样式