当前位置: 首页 > news >正文

自注意力,多头注意力,交叉注意力代码对比

自注意力、多头注意力与交叉注意力的PyTorch代码对比

1. 自注意力 (Self-Attention)

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.embed_dim = embed_dim# 投影矩阵:Q/K/V共享输入维度self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.softmax = nn.Softmax(dim=-1)def forward(self, x):"""x: (batch_size, seq_len, embed_dim)"""# 1. 生成Q/K/V - 全部来自同一输入Q = self.query(x)  # (B, L, D)K = self.key(x)    # (B, L, D)V = self.value(x)  # (B, L, D)# 2. 计算注意力分数attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))# 3. 注意力权重归一化attn_weights = self.softmax(attn_scores)  # (B, L, L)# 4. 加权求和output = torch.matmul(attn_weights, V)  # (B, L, D)return output

核心特征

  • Q/K/V全部来自同一个输入序列
  • 注意力分数矩阵维度为(L, L),表示序列内部的关系
  • 输出序列长度和维度不变

2. 多头注意力 (Multi-Head Attention)

class MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads# 确保可分割assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by num_heads"# 多头投影矩阵self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)# 输出层self.fc_out = nn.Linear(embed_dim, embed_dim)self.softmax = nn.Softmax(dim=-1)def split_heads(self, x):"""分割为多头"""batch_size = x.size(0)# (B, L, D) -> (B, L, H, HD) -> (B, H, L, HD)return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)def forward(self, x):"""多头自注意力"""# 1. 生成Q/K/VQ = self.query(x)K = self.key(x)V = self.value(x)# 2. 分割为多头Q = self.split_heads(Q)  # (B, H, L, HD)K = self.split_heads(K)V = self.split_heads(V)# 3. 计算注意力分数attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))# 4. 注意力权重归一化attn_weights = self.softmax(attn_scores)  # (B, H, L, L)# 5. 加权求和attention = torch.matmul(attn_weights, V)  # (B, H, L, HD)# 6. 合并多头attention = attention.transpose(1, 2).contiguous()  # (B, L, H, HD)attention = attention.view(attention.size(0), -1, self.embed_dim)  # (B, L, D)# 7. 输出投影output = self.fc_out(attention)return output

核心特征

  • 基于自注意力扩展
  • 额外的分割(head splitting)和合并操作
  • 每个头在降维后的子空间(HD)中计算
  • 最终通过全连接层融合多头信息

3. 交叉注意力 (Cross-Attention)

class CrossAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.embed_dim = embed_dim# Query来自序列A,Key/Value来自序列Bself.query = nn.Linear(embed_dim, embed_dim)  # for sequence Aself.key = nn.Linear(embed_dim, embed_dim)   # for sequence Bself.value = nn.Linear(embed_dim, embed_dim) # for sequence Bself.softmax = nn.Softmax(dim=-1)def forward(self, x_a, x_b):"""x_a: (batch_size, len_a, embed_dim)  序列Ax_b: (batch_size, len_b, embed_dim)  序列B"""# 1. 生成Q/K/V - 来自不同输入源Q = self.query(x_a)   # 来自序列A (B, La, D)K = self.key(x_b)     # 来自序列B (B, Lb, D)V = self.value(x_b)   # 来自序列B (B, Lb, D)# 2. 计算注意力分数 (序列A到序列B的映射)attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))# 3. 注意力权重归一化attn_weights = self.softmax(attn_scores)  # (B, La, Lb)# 4. 加权求和output = torch.matmul(attn_weights, V)  # (B, La, D)return output

核心特征

  • Q来自一个序列,K/V来自另一个序列
  • 注意力矩阵维度为(La, Lb),表示序列间关系
  • 输出序列长度与查询序列相同(La),维度不变
  • 不要求两个序列长度相同

三者的核心对比

特性自注意力多头注意力交叉注意力
输入序列数量1个1个2个
Q来源自身自身序列A
K/V来源自身自身序列B
维度变换分割头+合并
注意力矩阵(L, L)(H, L, L)(La, Lb)
输出长度LLLa
主要用途序列内关系多角度特征提取序列间关系建模
计算复杂度O(L²·D)O(H·L²·HD)O(La·Lb·D)

使用场景示例

# 示例:序列长度均为5,嵌入维度128
x = torch.randn(2, 5, 128)  # batch_size=2, seq_len=5, embed_dim=128
y = torch.randn(2, 3, 128)  # 不同长度序列# 1. 自注意力
self_attn = SelfAttention(embed_dim=128)
output_self = self_attn(x)  # (2, 5, 128)# 2. 多头注意力 (8头)
multihead_attn = MultiHeadAttention(embed_dim=128, num_heads=8)
output_multi = multihead_attn(x)  # (2, 5, 128)# 3. 交叉注意力
cross_attn = CrossAttention(embed_dim=128)
output_cross = cross_attn(x, y)  # (2, 5, 128) - 保持查询序列长度

性能优化技巧

  1. 融合计算:现代PyTorch版本推荐使用优化API

    # PyTorch 1.12+ 优化实现
    output = F.scaled_dot_product_attention(Q, K, V, attn_mask=None)
    
  2. 内存优化:使用计算过程重算减少内存占用

    with torch.cuda.amp.autocast(enabled=True):output = some_attention(Q, K, V)
    
  3. 稀疏注意力:对大序列使用稀疏矩阵

    from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
    

相关文章:

  • Cisco Packet Tracer软件如何修改文件存储位置
  • 制造业数字化转型解决方案及应用
  • 【Python训练营打卡】day43 @浙大疏锦行
  • C语言获取数组长度方法大全(附带实例)
  • 共聚焦显微镜—赋能光学元件精密质控
  • 常见优化器Optimizer总结
  • 论文润色指令
  • shell:基础
  • C语言数组初始化方法大全(附带实例)
  • JAVA 集合进阶 06 - 09 Map 集合的实现类:HashMap、LinkecHashMap
  • JAVA 集合进阶 Map集合的实现类 TreeMap
  • 电子电路:空气也会形成电容吗?
  • 并发工具【上】——线程池及其操作
  • Elasticsearch的插件(Plugin)系统介绍
  • 多态(全)
  • 企业级实战之Iptables防火墙案例分析
  • 11. MySQL事务管理(上)
  • 极客大挑战 2019 EasySQL 1(万能账号密码,SQL注入,HackBar)
  • 3.spring基础入门(三)
  • 打卡day44
  • 宁波汽车网站建设/接广告的平台
  • 网页设计作品评价/全网seo优化电话
  • 上海做网站比较有名的公司有哪些/免费发广告的平台
  • 外贸网站建设哪家比较好/百度手机极速版
  • 品牌推广怎么做/智推教育seo课程
  • 领地免费网站开发/最新seo视频教程