注意力机制的使用说明01
多头注意力机制(MHA)使用精要
核心作用: 捕捉序列数据的全局依赖关系,让每个时间点都能关注到所有其他时间点。
关键参数 (__init__
)
embed_dim
: 特征维度 (C)。必须与输入到MHA层的数据的特征维度完全一致。num_heads
: 头的数量。embed_dim
必须能被num_heads
整除。batch_first=True
: 务必设为True
。这规定了MHA期望的输入格式为(N, L, C)
。
实现蓝图 (forward
pass)
在卷积网络(输入为 (N, C, L)
)中使用MHA,遵循以下三步即可:
格式转换 (Permute In):
x = x.permute(0, 2, 1)
目的:将
(N, C, L)
转换为MHA期望的(N, L, C)
。
应用注意力块 (Attention Block):
attn_out, _ = self.mha(x, x, x)
x = self.norm(x + attn_out)
目的:执行自注意力计算,并用残差连接和层归一化稳定训练。
格式恢复 (Permute Back):
x = x.permute(0, 2, 1)
目的:将
(N, L, C)
转换回(N, C, L)
,以适配后续的卷积层。
黄金法则: MHA的 embed_dim
参数值,必须等于你的数据在进入MHA模块时的特征维度(通道数C),而不是最原始信号的维度。
import torch
import torch.nn as nnclass AttentionBlock(nn.Module):def __init__(self, embed_dim, num_heads):super(AttentionBlock, self).__init__()# 确保 embed_dim 能被 num_heads 整除if embed_dim % num_heads != 0:raise ValueError(f"embed_dim ({embed_dim}) 必须能被 num_heads ({num_heads}) 整除。")self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)self.norm = nn.LayerNorm(embed_dim)def forward(self, x):# x 的输入格式应为 (N, C, L),这是CNN的典型输出格式N, C, L = x.shape# --- 配方第1步: 格式准备 ---# (N, C, L) -> (N, L, C)x_permuted = x.permute(0, 2, 1)# --- 配方第2步: 自注意力计算 ---attn_output, _ = self.mha(x_permuted, x_permuted, x_permuted)# --- 配方第3步: 稳定与融合 ---# 残差连接 + 层归一化x_stabilized = self.norm(x_permuted + attn_output)# --- 配方第4步: 格式恢复 ---# (N, L, C) -> (N, C, L)final_output = x_stabilized.permute(0, 2, 1)return final_output# --- 使用示例 ---
# 假设我们有一个来自CNN的输出
cnn_output = torch.randn(32, 64, 1024) # (N, C, L)# 创建并使用注意力块
attention_block = AttentionBlock(embed_dim=64, num_heads=8)
processed_output = attention_block(cnn_output)print(f"输入形状: {cnn_output.shape}")
print(f"输出形状: {processed_output.shape}") # 输出形状应与输入完全相同