当前位置: 首页 > news >正文

Attention:MHA->MQA->GQA->MLA

Transformer 的注意力机制经历了从 MHA(多头注意力)MQA(多查询注意力)GQA(分组查询注意力),再到 MLA(多头潜变量注意力) 的逐步演进。这一过程的核心目标是:减少计算和显存开销,同时保持模型性能。

MHA(Multi-Head Attention,多头注意力)

MHA 是最早出现在 Transformer(Vaswani et al., 2017) 中的注意力形式。它通过 多组独立的注意力头(heads) 来并行捕捉不同子空间的关系。

数学形式:

  • 输入向量 X \in \mathbb{R}^{n \times d} ,经过线性变换得到:Q_i = XW_i^Q, \quad K_i = XW_i^K, \quad V_i = XW_i^V, \quad i=1,\dots,h

  • 对每个 head:\text{head}_i = \text{Softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right)V_i

  • 最后拼接:\text{MHA}(X) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

特点:

  • 每个注意力头都有自己独立的 W^Q, W^K, W^V,多个头可以同时计算,提高计算效率,但显存占用和计算量较大
  • 模型表达力强,能够捕获复杂的上下文关系,但参数多,计算开销大
  • 随着模型规模扩大,MHA 的参数和显存开销呈线性增长,尤其是 Key 和 Value 的存储成为瓶颈
import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.qkv = nn.Linear(embed_dim, 3 * embed_dim)self.proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):B, T, C = x.shapeqkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # [B, num_heads, T, head_dim]attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)attn = torch.softmax(attn, dim=-1)out = (attn @ v).transpose(1, 2).reshape(B, T, C)return self.proj(out)# 使用示例
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(1, 10, 512)  # [batch, seq_len, dim]
print(mha(x).shape)  # [1, 10, 512]

MQA(Multi-Query Attention,多查询注意力)

在传统的多头注意力机制中,每个注意力头都使用自己的一组查询、键和值,这可能需要大量计算,尤其是在注意力头数量增加的情况下。

多查询注意力机制 (MQA) 是 Transformer 中使用的传统多头自注意力机制(MHA)的一种变体。MQA 通过在多个注意力头之间共享同一组键和值,同时为每个注意力头维护不同的查询。

即:在 解码(inference) 阶段,MHA 的计算瓶颈主要在于存储每个 head 的 Key/Value 缓存。MQA 的改进是:多个 Query heads 共享同一个 Key 和 Value

核心思想:为了解决推理时 Key/Value 缓存过大的问题,所有头共享同一组 Key 和 Value

  • Query:每个头独立
  • Key / Value:所有头共享一组

特点:

  • Q 独立,K,V 全部共享
  • 大幅减少 KV 缓存,推理速度更快,显存占用更低,KV 缓存减少约 h 倍 (h是头数)
  • 每个头看到的 Key/Value 相同 → 表达能力略有下降,即共享 K 和 V 可能导致模型捕捉上下文的能力下降
class MultiQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.q = nn.Linear(embed_dim, embed_dim)  # 独立 Qself.k = nn.Linear(embed_dim, self.head_dim)  # 共享 Kself.v = nn.Linear(embed_dim, self.head_dim)  # 共享 Vself.proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):B, T, C = x.shapeq = self.q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, T, D]k = self.k(x).unsqueeze(1)  # [B, 1, T, D] -> 广播到所有头v = self.v(x).unsqueeze(1)attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)attn = torch.softmax(attn, dim=-1)out = (attn @ v).transpose(1, 2).reshape(B, T, C)return self.proj(out)# 使用示例
mqa = MultiQueryAttention(embed_dim=512, num_heads=8)
print(mqa(x).shape)  # [1, 10, 512]

GQA(Grouped Query Attention,分组查询注意力)

组查询注意力 (GQA) 是对 Transformer 中使用的传统多头自注意力机制和多查询注意力机制的折中。在标准多头自注意力中,每个注意力头独立处理整个序列。这种方法虽然功能强大,但计算成本高昂,尤其是对于长序列。而MQA虽然通过在多个注意力头之间共享同一组键和值简化了这一过程,但其简化也不可避免的带来了一些精度的损失。GQA 通过将查询分组在一起来解决此问题,从而降低了计算复杂性,而不会显著影响性能。

核心思想:GQA 是 MHA 和 MQA 的折中方案:将多个 Query 头划分为若干组,每组共享一组 Key/Value,Q 独立

  • 每组包含多个 Query heads
  • 每组有独立的 Key 和 Value
  • 介于“每头独立”和“全部共享”之间

特点:

  • 减少显存,KV Cache 减少到 g/h同时保留了部分多样性,性能接近 MHA
  • 需要合理设置组数 g,组数过少可能接近 MQA,过多则接近 MHA
  • 被广泛采用(PaLM 2、Gemini、LLaMA 2、Mixtral 等)
class GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads, num_groups):super().__init__()self.num_heads = num_headsself.num_groups = num_groupsself.head_dim = embed_dim // num_headsassert num_heads % num_groups == 0, "头数必须能被组数整除"self.q = nn.Linear(embed_dim, embed_dim)self.k = nn.Linear(embed_dim, self.head_dim * num_groups)  # 每组一个 Kself.v = nn.Linear(embed_dim, self.head_dim * num_groups)  # 每组一个 Vself.proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):B, T, C = x.shapeq = self.q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, T, D]k = self.k(x).reshape(B, T, self.num_groups, self.head_dim).transpose(1, 2)  # [B, G, T, D]v = self.v(x).reshape(B, T, self.num_groups, self.head_dim).transpose(1, 2)# 将 K/V 广播到每个组内的头k = k.repeat_interleave(self.num_heads // self.num_groups, dim=1)v = v.repeat_interleave(self.num_heads // self.num_groups, dim=1)attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)attn = torch.softmax(attn, dim=-1)out = (attn @ v).transpose(1, 2).reshape(B, T, C)return self.proj(out)# 使用示例(4 组,8 头)
gqa = GroupedQueryAttention(embed_dim=512, num_heads=8, num_groups=4)
print(gqa(x).shape)  # [1, 10, 512]

MLA(Multi-Head Latent Attention,多头潜变量注意力)

多头潜在注意力 (MLA) 将潜在特征表示纳入注意力机制,以降低计算复杂度并改善上下文表示。MLA的核心是对KV进行压缩后,再送入标准的MHA算法中,用一个更短的k,v向量来进行计算,进而减少KV Cache的大小。

核心思想:在 GQA 的基础上进一步优化:不再直接存储 KV,而是引入一个低维“潜空间”(latent space)生成 KV,从而减少 KV Cache 的大小

工作机制:

  1. 将输入 token 投影到一个潜向量空间(通常维度更低)
  2. Key/Value 通过该潜向量生成
  3. 每个注意力头在潜空间中计算
  4. 减少 KV 缓存存储,同时保持多头的表达多样性

特点:

  • 显著减少 KV 缓存,减少 93.3%,适合超长序列推理
  • 推理更快,尤其在长上下文时
  • 性能与 GQA 相当甚至更优
  • GQA 是“多个头共享同一组 KV”,MLA 则是“多个头共享一个低维潜空间,从该空间动态生成 KV”
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadLocalAttention(nn.Module):def __init__(self, embed_dim, num_heads, window_size=4):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.window_size = window_sizeself.qkv = nn.Linear(embed_dim, 3 * embed_dim)self.proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):B, T, C = x.shapeqkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # [B, H, T, D]# 划分局部窗口x = x.view(B, T, C)x = x.unfold(1, self.window_size, self.window_size)  # [B, num_windows, window_size, C]# 每个窗口内计算注意力local_attn_outputs = []for i in range(x.size(1)):window = x[:, i, :, :]  # [B, window_size, C]q_window = q[:, :, i*self.window_size:(i+1)*self.window_size, :]k_window = k[:, :, i*self.window_size:(i+1)*self.window_size, :]v_window = v[:, :, i*self.window_size:(i+1)*self.window_size, :]attn = (q_window @ k_window.transpose(-2, -1)) * (self.head_dim ** -0.5)attn = torch.softmax(attn, dim=-1)out_window = (attn @ v_window).transpose(1, 2).reshape(B, self.window_size, C)local_attn_outputs.append(out_window)# 合并窗口结果out = torch.cat(local_attn_outputs, dim=1)return self.proj(out)# 使用示例
mla = MultiHeadLocalAttention(embed_dim=512, num_heads=8, window_size=4)
x = torch.randn(1, 20, 512)  # [batch, seq_len, dim]
print(mla(x).shape)  # [1, 20, 512]

这篇文章也写的挺好的,可以参考看看:https://lengm.cn/post/20250226_attention/ 

style="display: none !important;">

http://www.dtcms.com/a/511741.html

相关文章:

  • 拥塞控制原理
  • Flink Kafka 生产者原理与实现
  • 路由器和机顶盒的射频核心:深入解析PA、LNA、PHY与滤波器
  • Java----set
  • python编程网站推荐郑州云帆网站设计
  • 如何做论文网站给我一个用c 做的网站
  • 青岛网站排名公司自己的网站如何让百度收录
  • MQTT主题架构的艺术:从字符串拼接走向设计模式
  • i.MAX6ULL Linux LED 字符设备驱动代码分析
  • Linux中基数树的初始化
  • 4.3 二维数组
  • 【C语言实战(40)】C语言查找算法:从基础到实战的效率进阶
  • 洛谷 P2949 [USACO09OPEN] Work Scheduling G
  • 建站公司杭州南宁制作网站服务商
  • Deepseek-ocr论文精读
  • 【完整源码+数据集+部署教程】【文件&发票】发票信息提取系统源码&数据集全套:改进yolo11-ContextGuided
  • SpringBoot+Shiro+mybatis教务管理系统源码
  • 佛山个人制作网站公司手机百度下载免费安装
  • Git 项目开发核心指南:聚焦常用语法与完整流程
  • 【图像处理基石】遥感多光谱图像处理入门:从概念到实战(附Python代码)
  • Spring Boot项目中使用线程池并发插入6万条数据的线程池参数设置指南
  • 网站建设网站设计哪家专业东莞展馆设计公司
  • Docker Swarm:打造高效、可扩展的容器编排引擎,引领微服务新纪元(上)
  • 第15章:Spring AI Alibaba — 认识Graph框架
  • [Dify 实战] 构建一个自动发送邮件的插件:从 OpenAPI 到自动化通知
  • 基于Chrome140的FB账号自动化(关键词浏览)——脚本撰写(二)
  • CICD实战(8) - 使用Arbess+GitLab实现React.js项目自动化部署
  • 小程序uview actionSheet 内容过多高度设置
  • 基于.net的个人网站开发实录哪个网站建站比较好
  • 徐州做网站公司哪家好湘建网