当前位置: 首页 > news >正文

注意力机制的使用说明01

多头注意力机制(MHA)使用精要

核心作用: 捕捉序列数据的全局依赖关系,让每个时间点都能关注到所有其他时间点。

关键参数 (__init__)
  1. embed_dim: 特征维度 (C)。必须与输入到MHA层的数据的特征维度完全一致。

  2. num_heads: 头的数量embed_dim 必须能被 num_heads 整除。

  3. batch_first=True: 务必设为 True。这规定了MHA期望的输入格式为 (N, L, C)

实现蓝图 (forward pass)

在卷积网络(输入为 (N, C, L))中使用MHA,遵循以下三步即可:

  1. 格式转换 (Permute In):

    • x = x.permute(0, 2, 1)

    • 目的:将 (N, C, L) 转换为MHA期望的 (N, L, C)

  2. 应用注意力块 (Attention Block):

    • attn_out, _ = self.mha(x, x, x)

    • x = self.norm(x + attn_out)

    • 目的:执行自注意力计算,并用残差连接和层归一化稳定训练。

  3. 格式恢复 (Permute Back):

    • x = x.permute(0, 2, 1)

    • 目的:将 (N, L, C) 转换回 (N, C, L),以适配后续的卷积层。

黄金法则: MHA的 embed_dim 参数值,必须等于你的数据在进入MHA模块时的特征维度(通道数C),而不是最原始信号的维度。

import torch
import torch.nn as nnclass AttentionBlock(nn.Module):def __init__(self, embed_dim, num_heads):super(AttentionBlock, self).__init__()# 确保 embed_dim 能被 num_heads 整除if embed_dim % num_heads != 0:raise ValueError(f"embed_dim ({embed_dim}) 必须能被 num_heads ({num_heads}) 整除。")self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)self.norm = nn.LayerNorm(embed_dim)def forward(self, x):# x 的输入格式应为 (N, C, L),这是CNN的典型输出格式N, C, L = x.shape# --- 配方第1步: 格式准备 ---# (N, C, L) -> (N, L, C)x_permuted = x.permute(0, 2, 1)# --- 配方第2步: 自注意力计算 ---attn_output, _ = self.mha(x_permuted, x_permuted, x_permuted)# --- 配方第3步: 稳定与融合 ---# 残差连接 + 层归一化x_stabilized = self.norm(x_permuted + attn_output)# --- 配方第4步: 格式恢复 ---# (N, L, C) -> (N, C, L)final_output = x_stabilized.permute(0, 2, 1)return final_output# --- 使用示例 ---
# 假设我们有一个来自CNN的输出
cnn_output = torch.randn(32, 64, 1024) # (N, C, L)# 创建并使用注意力块
attention_block = AttentionBlock(embed_dim=64, num_heads=8)
processed_output = attention_block(cnn_output)print(f"输入形状: {cnn_output.shape}")
print(f"输出形状: {processed_output.shape}") # 输出形状应与输入完全相同

http://www.dtcms.com/a/298736.html

相关文章:

  • RNN模型数学推导过程(笔记)
  • 散列表(哈希表)
  • SQL基础⑮ | 触发器
  • 亚德诺半导体AD8539ARZ-REEL7 超低功耗轨到轨运算放大器,自动归零技术,专为可穿戴设备设计!
  • Python 程序设计讲义(20):选择结构程序设计——双分支结构的简化表示(三元运算符)
  • 【linux】Haproxy七层代理
  • 电子基石:硬件工程师的器件手册 (八) - 栅极驱动IC:功率器件的神经中枢
  • 【自动化运维神器Ansible】Ansible常用模块之Copy模块详解
  • 程序代码篇---卡尔曼滤波与PID的组合应用
  • 2.Linux 网络配置
  • 【PyTorch】图像多分类项目部署
  • python基础:request模块简介与安装、基本使用,如何发送get请求响应数据,response属性与请求头
  • centOS7 yum安装新版本的cmake,cmake3以上怎么安装,一篇文章说明白
  • Java并发编程第十篇(ThreadPoolExecutor线程池组件分析)
  • 无印 v1.6 视频解析去水印工具,支持多个平台
  • Android悬浮窗导致其它应用黑屏问题解决办法
  • RocketMQ 5.3.0 ARM64 架构安装部署指南
  • J2EE模式---数据访问对象模式
  • C语言案例《猜拳游戏》
  • VSCode 报错 Error: listen EACCES: permission denied 0.0.0.0:2288
  • Java 笔记 interface
  • C#入门实战:数字计算与条件判断
  • Web攻防-业务逻辑篇密码找回重定向目标响应包检验流程跳过回显泄露验证枚举
  • 【PyTorch】图像多分类项目
  • 一些常见的网络攻击方式
  • CY5-OVA科研方向,星戈瑞荧光
  • Pytest tmp_path 实战指南:测试中的临时目录管理
  • C语言————原码 补码 反码 (日渐清晰版)
  • MinIO 安装指南 - Linux ARM64
  • Linux网络管理与IP配置实验指南