Attention Module Improvements: Principles and Implementations (MHA, MQA, GQA, MLA)
MHA, MQA, GQA
MHA — Multi-Head Attention: Q, K, and V each get their own projection, and every attention head is fully independent.
MQA — Multi-Query Attention: Q is projected independently per head, while all heads share a single K/V, which shrinks the KV cache.
GQA — Grouped Query Attention: query heads are split into groups and each group shares one K/V projection, reducing the KV cache while keeping some per-group independence (a rough cache-size comparison is sketched below).
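As a back-of-the-envelope sketch of the KV-cache savings: per token and per layer, the cache holds K and V for every KV head, so shrinking the number of KV heads shrinks the cache proportionally. The head counts and dimensions below are illustrative values, not from any particular model.

# Per-token, per-layer KV-cache elements: 2 (K and V) * num_kv_heads * head_dim
def kv_cache_per_token(num_kv_heads, head_dim):
    return 2 * num_kv_heads * head_dim

num_heads, head_dim = 32, 128
print(kv_cache_per_token(num_heads, head_dim))  # MHA: 32 KV heads -> 8192 elements
print(kv_cache_per_token(8, head_dim))          # GQA:  8 KV heads -> 2048 elements (4x smaller)
print(kv_cache_per_token(1, head_dim))          # MQA:  1 KV head  ->  256 elements (32x smaller)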
Basic version: MHA
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiheadAttention_v1(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_dropout=0.0):
        super().__init__()
        assert hidden_size % num_attention_heads == 0
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape
        # Project to Q/K/V and split into heads: (batch, num_heads, seq_len, head_dim)
        q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = F.dropout(attn_probs, p=self.attention_dropout, training=self.training)
        # Weighted sum over values, then merge heads and project back to hidden_size
        attn_output = torch.matmul(attn_probs, v)  # (batch, num_heads, seq_len, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_probs
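A quick shape check for the module above, using arbitrary illustrative sizes:

attn = MultiheadAttention_v1(hidden_size=512, num_attention_heads=8)
hidden_states = torch.randn(2, 16, 512)   # (batch, seq_len, hidden_size)
out, probs = attn(hidden_states)
print(out.shape)    # torch.Size([2, 16, 512])
print(probs.shape)  # torch.Size([2, 8, 16, 16])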
Advanced version: GQA (adapted from the implementation in the transformers library)
from __future__ import annotations  # defer evaluation of the Config annotation used below

from typing import Optional

import torch
import torch.nn as nn


def rotate_half(x):
    # Rotate half the hidden dims: (x1, x2) -> (-x2, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Apply rotary position embeddings (RoPE) to q and k
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each K/V head n_rep times so K/V line up with the query heads
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
):
    # Expand the grouped K/V heads, then do standard scaled dot-product attention
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Attention(nn.Module):
    def __init__(self, config: Config, layer_idx: int):
        super().__init__()
        hidden_size = config.hidden_size
        num_key_value_heads = config.num_key_value_heads
        num_attention_heads = config.num_attention_heads
        attention_dropout = config.attention_dropout
        assert hidden_size % num_attention_heads == 0
        assert num_attention_heads % num_key_value_heads == 0

        self.layer_idx = layer_idx
        self.head_dim = hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        # Q keeps one projection per query head; K/V are only projected to num_key_value_heads heads
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, num_heads or num_kv_heads, seq_len, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
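The snippet below runs the GQA module end to end. The Config here is a hypothetical stand-in (a plain namespace with the fields the module reads), and the RoPE cos/sin tables are built with the standard non-interleaved layout that apply_rotary_pos_emb above assumes; neither is taken verbatim from transformers.

from types import SimpleNamespace

config = SimpleNamespace(
    hidden_size=512,
    num_attention_heads=8,
    num_key_value_heads=2,   # every 4 query heads share one K/V head
    attention_dropout=0.0,
)
attn = Attention(config, layer_idx=0)

batch_size, seq_len = 2, 16
head_dim = config.hidden_size // config.num_attention_heads
# RoPE tables of shape (1, seq_len, head_dim), broadcast over the batch
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None], emb.sin()[None]

hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
out, weights = attn(hidden_states, (cos, sin), attention_mask=None)
print(out.shape, weights.shape)  # torch.Size([2, 16, 512]) torch.Size([2, 8, 16, 16])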
MLA
MLA — Multi-head Latent Attention: Q, K, and V are still separate, but the input is first down-projected into low-rank latent vectors for Q and for K/V, which are then up-projected before the usual attention computation; this reduces activation memory during training. For inference there are two optimizations: first, only the intermediate KV latent vector needs to be cached instead of each head's K and V; second, the "down-project to low rank, then up-project" pair can be folded into a single matrix multiplication via matrix absorption, improving compute efficiency. However, compressing into the latent space breaks RoPE's position sensitivity, so DeepSeek-V3 adopts a hybrid scheme: most of each head's dimensions use RoPE-free MLA, while a small set of decoupled dimensions carries conventional RoPE with the rotated key shared across all heads (MQA-style), balancing efficiency against compatibility with positional encoding.
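To make the caching benefit concrete, compare what must be cached per token and per layer. The numbers below are illustrative, loosely DeepSeek-V3-scale, rather than exact model hyperparameters:

# Caching full per-head K and V (MHA-style)
num_heads, qk_head_dim, v_head_dim = 128, 192, 128
full_kv_per_token = num_heads * (qk_head_dim + v_head_dim)   # 40960 elements

# MLA caches one compressed KV latent plus the shared decoupled RoPE key
kv_lora_rank, qk_rope_head_dim = 512, 64
mla_cache_per_token = kv_lora_rank + qk_rope_head_dim        # 576 elements

print(full_kv_per_token / mla_cache_per_token)               # ~71x smaller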
Simplified version of the transformers implementation
# Reuses apply_rotary_pos_emb and eager_attention_forward from the GQA section above.

class DeepseekV3RMSNorm(nn.Module):
    # RMSNorm as used in the DeepSeek-V3 modeling code
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class DeepseekV3AttentionNoCache(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout    # attention dropout
        self.rope_theta = config.rope_theta                  # RoPE base
        self.hidden_size = config.hidden_size                # input hidden size
        self.qk_rope_head_dim = config.qk_rope_head_dim      # Q/K dims that carry RoPE
        self.qk_nope_head_dim = config.qk_nope_head_dim      # Q/K dims without positional encoding
        self.qk_head_dim = config.qk_head_dim                # total Q/K head dim (RoPE + no-RoPE)
        self.num_attention_heads = config.num_attention_heads  # number of attention heads
        self.q_lora_rank = config.q_lora_rank                # low-rank (latent) dim for Q
        self.kv_lora_rank = config.kv_lora_rank              # low-rank (latent) dim for K/V
        self.v_head_dim = config.v_head_dim                  # per-head V dim
        # kv_b_proj already yields one K/V per attention head, so no KV-head repetition is needed
        self.num_key_value_groups = 1

        # Q projection: either one full-rank projection or down-project -> norm -> up-project
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(self.hidden_size, self.q_lora_rank)
            self.q_a_layernorm = DeepseekV3RMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_attention_heads * self.qk_head_dim, bias=False)

        # K/V projection: down-project to the KV latent plus the shared RoPE key dims
        self.kv_a_proj = nn.Linear(self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank, self.num_attention_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False
        )

        # Output projection
        self.o_proj = nn.Linear(self.num_attention_heads * self.v_head_dim, self.hidden_size)
        self.scaling = self.qk_head_dim ** -0.5  # attention scaling factor

    def forward(
        self,
        hidden_states: torch.Tensor,                             # (batch, seq_len, hidden_size)
        position_embeddings: tuple[torch.Tensor, torch.Tensor],  # RoPE cos/sin
        attention_mask: Optional[torch.Tensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:-1]

        # Q projection -> (batch, num_heads, seq_len, qk_head_dim)
        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(batch_size, seq_length, self.num_attention_heads, self.qk_head_dim).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # K/V projection: split off the shared RoPE key, up-project the latent into per-head K/V
        kv = self.kv_a_proj(hidden_states)
        k_pass, k_rot = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass))
        k_pass = k_pass.view(
            batch_size, seq_length, self.num_attention_heads, self.qk_nope_head_dim + self.v_head_dim
        ).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        # Apply RoPE only to the decoupled dims; the rotated key is shared across all heads
        cos, sin = position_embeddings
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # Concatenate the no-RoPE and RoPE parts of Q and K
        query_states = torch.cat([q_pass, q_rot], dim=-1)
        key_states = torch.cat([k_pass, k_rot], dim=-1)

        # Attention
        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
        )

        # Output projection
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
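A shape check for the simplified MLA module, reusing the RoPE helpers from the GQA section. The config values here are small made-up numbers chosen only so the tensor shapes line up; they are not DeepSeek-V3's real hyperparameters.

from types import SimpleNamespace

config = SimpleNamespace(
    attention_dropout=0.0,
    rope_theta=10000.0,
    hidden_size=1024,
    qk_rope_head_dim=64,
    qk_nope_head_dim=128,
    qk_head_dim=192,          # qk_nope_head_dim + qk_rope_head_dim
    num_attention_heads=8,
    q_lora_rank=384,
    kv_lora_rank=256,
    v_head_dim=128,
)
attn = DeepseekV3AttentionNoCache(config, layer_idx=0)

batch_size, seq_len = 2, 16
# RoPE tables only cover the decoupled dims: shape (1, seq_len, qk_rope_head_dim)
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, config.qk_rope_head_dim, 2).float() / config.qk_rope_head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None], emb.sin()[None]

hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
out, weights = attn(hidden_states, (cos, sin))
print(out.shape, weights.shape)  # torch.Size([2, 16, 1024]) torch.Size([2, 8, 16, 16])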