GQA (Grouped Query Attention): Principles and Practice of the Grouped Attention Mechanism (Part 3)
GQA is a multi-head attention variant in which every head keeps its own Query while heads within a group share Key/Value projections. This improves inference efficiency and reduces redundant computation.
✅ GQA vs Multi-Head Attention (MHA)
• MHA: every head has its own independent Q/K/V
• GQA: every head has its own Q, but heads within the same group share K/V (see the sketch after this list)
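To make the difference concrete, here is a minimal back-of-the-envelope sketch of the per-token KV cache at inference time. It uses the same sizes as the example later in this post (hidden_size=768, 12 heads, 4 K/V groups); the numbers are illustrative, not from the original post.

# Illustrative sketch (assumed sizes): KV-cache entries stored per token, per layer
head_dim = 64        # hidden_size 768 / 12 heads
mha_kv_heads = 12    # MHA: one K and one V per head
gqa_kv_heads = 4     # GQA: one K and one V per group
mha_cache = 2 * mha_kv_heads * head_dim   # K + V -> 1536 values per token per layer
gqa_cache = 2 * gqa_kv_heads * head_dim   # K + V -> 512 values per token per layer
print(mha_cache, gqa_cache, mha_cache / gqa_cache)  # 1536 512 3.0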
🚀 A Simple PyTorch Implementation of GQA
import torch
import torch.nn as nn
import torch.nn.functional as F

class GQAAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, num_kv_groups=1):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.num_kv_groups = num_kv_groups
        assert num_heads % num_kv_groups == 0
        # Each head has its own Q projection
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        # K and V are shared group-wise, so their output dim is num_kv_groups * head_dim
        self.k_proj = nn.Linear(hidden_size, self.head_dim * num_kv_groups)
        self.v_proj = nn.Linear(hidden_size, self.head_dim * num_kv_groups)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        B, T, _ = x.size()
        # Q: [B, T, H * D] → [B, H, T, D]
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # K/V: [B, T, G * D] → [B, G, T, D]
        k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)
        # Expand K/V so that each head is paired with its group's K/V
        heads_per_group = self.num_heads // self.num_kv_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Attention: [B, H, T, D] x [B, H, D, T] → [B, H, T, T]
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_probs = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_probs, v)  # [B, H, T, D]
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
        return self.out_proj(attn_output)
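As a side note (not from the original post), the module above also covers the two familiar extremes as configurations: setting num_kv_groups equal to num_heads recovers standard MHA, and setting it to 1 gives multi-query attention (MQA). A quick sketch:

# Sketch: MHA and MQA as special cases of the same GQAAttention module
x = torch.randn(2, 10, 768)
mha_like = GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=12)  # one K/V per head
mqa_like = GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=1)   # one K/V shared by all heads
print(mha_like(x).shape)  # torch.Size([2, 10, 768])
print(mqa_like(x).shape)  # torch.Size([2, 10, 768])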
🧠 Parameter Explanation
Parameter        Meaning
hidden_size      total hidden dimension of the model
num_heads        number of Query heads
num_kv_groups    number of K/V groups (at most num_heads, and must divide it evenly)
heads_per_group  number of Q-heads that share one K/V group
📌 Example: What the Settings Mean
GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=4)
This configuration means:
• there are 12 Q-heads (each independent)
• there are only 4 K/V groups (each shared)
• every 3 Q-heads share 1 KV group (see the sketch after this list)
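A quick sketch of the head-to-group assignment implied by repeat_interleave in the implementation above, using the 12-head / 4-group example:

# After repeat_interleave, head h reads the K/V of group h // heads_per_group
num_heads = 12
num_kv_groups = 4
heads_per_group = num_heads // num_kv_groups  # 3
head_to_group = [h // heads_per_group for h in range(num_heads)]
print(head_to_group)  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]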
✅ A Test Function for GQAAttention (PyTorch)
def test_gqa():
    import torch
    # Parameter settings
    batch_size = 2
    seq_len = 10
    hidden_size = 768
    num_heads = 12
    num_kv_groups = 4
    # Build the GQA module
    gqa = GQAAttention(hidden_size=hidden_size, num_heads=num_heads, num_kv_groups=num_kv_groups)
    # Random input: [B, T, H]
    dummy_input = torch.randn(batch_size, seq_len, hidden_size)
    # Forward pass
    output = gqa(dummy_input)
    # Print the shapes
    print("Input shape:", dummy_input.shape)
    print("Output shape:", output.shape)
    # Assert the output shape matches the input
    assert output.shape == (batch_size, seq_len, hidden_size), "Output shape mismatch!"
    print("✅ GQA forward pass test passed.")

if __name__ == "__main__":
    test_gqa()
Output
Input shape: torch.Size([2, 10, 768])
Output shape: torch.Size([2, 10, 768])
✅ GQA forward pass test passed.
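Beyond the shape check above, here is a rough sketch (same illustrative sizes as before) of how the shared K/V projections also shrink the attention block's weight count; the exact counts follow from the Linear layers defined in GQAAttention.

# Sketch: projection parameter counts, MHA-style vs. GQA-style configuration
def count_params(module):
    return sum(p.numel() for p in module.parameters())

mha_like = GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=12)
gqa = GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=4)
print(count_params(mha_like))  # 2362368  (= 4 * (768*768 + 768))
print(count_params(gqa))       # 1574912  (= 2*(768*768 + 768) + 2*(768*256 + 256))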