当前位置: 首页 > news >正文

GQA(Grouped Query Attention):分组注意力机制的原理与实践《三》

GQA 是一种在多头注意力中共享 Key/Value,但拥有独立 Query 的结构,用于提升推理效率、减少冗余计算。

✅ GQA vs 多头注意力 (MHA)

•	MHA:每个 head 都有独立的 Q/K/V
•	GQA:每个 head 有独立 Q,但共享组内 K/V

🚀 GQA 简易 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass GQAAttention(nn.Module):def __init__(self, hidden_size, num_heads, num_kv_groups=1):super().__init__()assert hidden_size % num_heads == 0self.hidden_size = hidden_sizeself.num_heads = num_headsself.head_dim = hidden_size // num_headsself.num_kv_groups = num_kv_groupsassert num_heads % num_kv_groups == 0# 每个 head 的 Q 独立self.q_proj = nn.Linear(hidden_size, hidden_size)# K 和 V 是共享的(Group-wise),因此维度为 num_kv_groups * head_dimself.k_proj = nn.Linear(hidden_size, self.head_dim * num_kv_groups)self.v_proj = nn.Linear(hidden_size, self.head_dim * num_kv_groups)self.out_proj = nn.Linear(hidden_size, hidden_size)def forward(self, x):B, T, _ = x.size()# Q: [B, T, H * D] → [B, H, T, D]q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# K/V: [B, T, G * D] → [B, G, T, D]k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)# 将 KV 扩展到每个 head(head 与 group 对应)heads_per_group = self.num_heads // self.num_kv_groupsk = k.repeat_interleave(heads_per_group, dim=1)v = v.repeat_interleave(heads_per_group, dim=1)# Attention: [B, H, T, D] x [B, H, D, T] → [B, H, T, T]attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_probs = F.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_probs, v)  # [B, H, T, D]attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)return self.out_proj(attn_output)

🧠 参数解释

参数名 含义
hidden_size 模型总隐藏维度
num_heads Query 的数量
num_kv_groups K/V 分组数量(小于 num_heads)
heads_per_group 每组多少个 head 共享一个 KV

📌 举例:设置说明

GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=4)

含义为:
• 有 12 个 Q-head(每个独立)
• 只有 4 个 K/V group(被共享)
• 每 3 个 Q-head 共享 1 个 KV group

✅ GQAAttention 测试函数(PyTorch)

def test_gqa():import torch# 参数设置batch_size = 2seq_len = 10hidden_size = 768num_heads = 12num_kv_groups = 4# 构造 GQA 模块gqa = GQAAttention(hidden_size=hidden_size, num_heads=num_heads, num_kv_groups=num_kv_groups)# 随机构造输入:[B, T, H]dummy_input = torch.randn(batch_size, seq_len, hidden_size)# 执行前向传播output = gqa(dummy_input)# 打印输出维度print("Input shape:", dummy_input.shape)print("Output shape:", output.shape)# 断言输出维度匹配输入assert output.shape == (batch_size, seq_len, hidden_size), "Output shape mismatch!"print("✅ GQA forward pass test passed.")if __name__ == "__main__":test_gqa()

输出

Input shape: torch.Size([2, 10, 768])
Output shape: torch.Size([2, 10, 768])
✅ GQA forward pass test passed.

相关文章:

  • Linux 环境下 PPP 拨号的嵌入式开发实现
  • 网络可靠性的定义与核心要素
  • 用户 xxx is not in the sudoers file.
  • FEMFAT许可分析中的关键指标
  • CentOS在vmware局域网内搭建DHCP服务器【踩坑记录】
  • html2canvas v1.0.0-alpha.12版本文本重叠问题修复
  • qt+vs Generated File下的moc_和ui_文件丢失导致 error LNK2001
  • Unity安卓平台开发,启动app并传参
  • 使用 SseEmitter 实现 Spring Boot 后端的流式传输和前端的数据接收
  • 麒麟+ARM架构安装mysql8的操作指南
  • setting up Activiti BPMN Workflow Engine with Spring Boot
  • 霍夫曼编码详解
  • 2025Mybatis最新教程(三)
  • 【向量化模型如何私有化部署】一文说清原理、流程与最佳实践
  • KTH5772游戏手柄摇杆专用 3D 霍尔位置传感器
  • JavaWeb:前后端分离开发-登录认证
  • uniapp uni-id-co errCode“:“uni-id-captcha-required“,“errMsg“:“Captcha required
  • 《Offer来了:Java面试核心知识点精讲》大纲
  • 第十一部分:进程通信
  • 通过ca证书的方式设置允许远程访问Docker服务
  • 如何搭建一个公司网站/福州搜索引擎优化公司
  • 网站建设资料 优帮云/百度下载app安装
  • 广州b2b网站建设/电商软文范例300字
  • 建设网站需要注意什么问题/长春百度推广公司
  • 爬黄山旅游攻略游览路线/吉林百度seo公司
  • wordpress 加速版/搜索引擎优化seo多少钱