Principles and Implementations of Attention Module Improvements (MHA, MQA, GQA, MLA)

MHA, MQA, GQA

MHA (Multi-Head Attention): Q, K, and V each get their own projection, and every attention head is fully independent.
MQA (Multi-Query Attention): Q keeps a per-head projection while all heads share a single K/V, which shrinks the KV cache.
GQA (Grouped Query Attention): query heads are split into groups and each group shares one K/V, reducing the KV cache while retaining some head independence; see the cache-size sketch below.
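
To make the KV-cache difference concrete, here is a minimal sketch that counts per-token, per-layer KV-cache entries for each variant. The dimensions are hypothetical (not taken from this article), chosen only for illustration:

# Hypothetical dimensions for illustration only
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8          # GQA: 4 query heads share each K/V head
head_dim = hidden_size // num_attention_heads  # 128

# Per token and per layer, the cache stores one K and one V vector per K/V head
kv_cache_mha = 2 * num_attention_heads * head_dim   # 8192 values: every head has its own K/V
kv_cache_mqa = 2 * 1 * head_dim                     # 256 values: one K/V shared by all heads
kv_cache_gqa = 2 * num_key_value_heads * head_dim   # 2048 values: one K/V per group

print(kv_cache_mha, kv_cache_mqa, kv_cache_gqa)     # 8192 256 2048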

Basic version: MHA

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiheadAttention_v1(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_dropout=0.0):
        super().__init__()
        assert hidden_size % num_attention_heads == 0
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape
        # Project to Q/K/V and split into heads: (batch, heads, seq_len, head_dim)
        q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = F.dropout(attn_probs, p=self.attention_dropout, training=self.training)
        # Output
        attn_output = torch.matmul(attn_probs, v)  # (batch, heads, seq_len, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_probs
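
A quick smoke test of the module above. The sizes and the additive causal mask (0 on allowed positions, -inf above the diagonal) are illustrative, not part of the original code:

batch_size, seq_len, hidden_size, num_heads = 2, 6, 64, 4   # illustrative sizes
mha = MultiheadAttention_v1(hidden_size, num_heads)

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# Additive causal mask, broadcast over batch and heads
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)[None, None, :, :]

out, probs = mha(hidden_states, attention_mask=causal_mask)
print(out.shape, probs.shape)   # torch.Size([2, 6, 64]) torch.Size([2, 4, 6, 6])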

Advanced version: GQA (adapted from the implementation in the transformers library)

from typing import Optional

import torch
import torch.nn as nn


def rotate_half(x):
    # Rotate half of the last dimension (standard RoPE helper, as in transformers)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


# Apply rotary position embeddings to q and k
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each K/V head n_rep times so every query head has a matching K/V head
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Attention(nn.Module):
    def __init__(self, config: "Config", layer_idx: int):
        super().__init__()
        hidden_size = config.hidden_size
        num_key_value_heads = config.num_key_value_heads
        num_attention_heads = config.num_attention_heads
        attention_dropout = config.attention_dropout
        assert hidden_size % num_attention_heads == 0
        self.layer_idx = layer_idx
        self.head_dim = hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        # Q keeps one projection per attention head; K/V only have num_key_value_heads projections
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
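
A minimal usage sketch for the GQA module above. The small Config dataclass and the transformers-style cos/sin tables (shape (batch, seq_len, head_dim)) are assumptions added here for illustration:

from dataclasses import dataclass

import torch


@dataclass
class Config:                      # minimal stand-in for a model config
    hidden_size: int = 64
    num_attention_heads: int = 8
    num_key_value_heads: int = 2   # 4 query heads share each K/V head
    attention_dropout: float = 0.0


config = Config()
attn = Attention(config, layer_idx=0)

batch_size, seq_len = 2, 6
head_dim = config.hidden_size // config.num_attention_heads
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)

# Transformers-style RoPE tables: cos/sin of shape (batch, seq_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
emb = torch.cat([freqs, freqs], dim=-1)[None].expand(batch_size, -1, -1)
cos, sin = emb.cos(), emb.sin()

out, weights = attn(hidden_states, (cos, sin), attention_mask=None)
print(out.shape, weights.shape)    # torch.Size([2, 6, 64]) torch.Size([2, 8, 6, 6])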
MLA

MLA (Multi-Head Latent Attention): Q, K, and V keep independent projections, but the input is first down-projected into low-rank Q and K/V latent vectors, which are then up-projected per head before the usual attention computation; this reduces activation memory during training. For inference there are two optimizations: first, instead of caching each head's K and V, only the intermediate K/V latent needs to be cached; second, the "down-project then up-project" steps can be merged into one at inference time by absorbing the matrices into each other, which improves compute efficiency. However, MLA's latent-space compression breaks RoPE's position sensitivity, so DeepSeek-V3 adopts a hybrid scheme: most of each head's dimensions use RoPE-free latent attention, while a small set of decoupled dimensions carries RoPE, with the RoPE key shared across all heads (MQA-style), striking a balance between efficiency and compatibility with the position encoding.
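
The "matrix absorption" point can be checked numerically: for the no-RoPE part, the attention score computed from up-projected Q and K equals the score computed directly against the cached latent once the Q and K up-projections are folded together. A toy single-head sketch; all weight names and dimensions are hypothetical, introduced only to show the identity:

import torch

torch.manual_seed(0)
hidden, q_rank, kv_rank, head_dim = 64, 16, 8, 32        # hypothetical sizes
x = torch.randn(2, 5, hidden)                            # (batch, seq, hidden)

W_dq  = torch.randn(hidden, q_rank) / hidden ** 0.5      # Q down-projection
W_uq  = torch.randn(q_rank, head_dim) / q_rank ** 0.5    # Q up-projection (one head)
W_dkv = torch.randn(hidden, kv_rank) / hidden ** 0.5     # K/V down-projection -> cached latent
W_uk  = torch.randn(kv_rank, head_dim) / kv_rank ** 0.5  # K up-projection (one head)

c_q, c_kv = x @ W_dq, x @ W_dkv                          # latents; only c_kv needs caching
# Naive: up-project both sides, then take dot products
scores_naive = (c_q @ W_uq) @ (c_kv @ W_uk).transpose(-1, -2)
# Absorbed: fold W_uk into the query path and score directly against the cached latent
W_absorbed = W_uq @ W_uk.T                               # precomputed once, (q_rank, kv_rank)
scores_absorbed = (c_q @ W_absorbed) @ c_kv.transpose(-1, -2)

print(torch.allclose(scores_naive, scores_absorbed, atol=1e-4))  # True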

A simplified version of the transformers implementation

class DeepseekV3RMSNorm(nn.Module):
    # Standard RMSNorm, as in transformers
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)


class DeepseekV3AttentionNoCache(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.attention_dropout = config.attention_dropout  # attention dropout
        self.rope_theta = config.rope_theta  # RoPE base
        self.hidden_size = config.hidden_size  # input hidden-state dimension
        self.qk_rope_head_dim = config.qk_rope_head_dim  # per-head Q/K dims that carry RoPE
        self.qk_nope_head_dim = config.qk_nope_head_dim  # per-head Q/K dims without positional encoding
        self.qk_head_dim = config.qk_head_dim  # total per-head Q/K dim (RoPE + no-RoPE parts)
        self.num_attention_heads = config.num_attention_heads  # number of attention heads
        self.q_lora_rank = config.q_lora_rank  # low-rank dimension for Q
        self.kv_lora_rank = config.kv_lora_rank  # low-rank dimension for K/V (the cached latent)
        self.v_head_dim = config.v_head_dim  # per-head V dim
        # After kv_b_proj, K/V are materialized for every head, so no group sharing is needed
        self.num_key_value_groups = 1
        # Q projection
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(self.hidden_size, self.q_lora_rank)
            self.q_a_layernorm = DeepseekV3RMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_attention_heads * self.qk_head_dim, bias=False)
        # K/V projection
        self.kv_a_proj = nn.Linear(self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank, self.num_attention_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False
        )
        # Output projection
        self.o_proj = nn.Linear(self.num_attention_heads * self.v_head_dim, self.hidden_size)
        self.scaling = self.qk_head_dim ** -0.5  # attention scaling factor

    def forward(
        self,
        hidden_states: torch.Tensor,  # (batch, seq_len, hidden_size)
        position_embeddings: tuple[torch.Tensor, torch.Tensor],  # RoPE cos/sin
        attention_mask: Optional[torch.Tensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:-1]
        # Q projection (optionally through the low-rank path), then split heads: (batch, heads, seq, qk_head_dim)
        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(batch_size, seq_length, self.num_attention_heads, self.qk_head_dim).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # K/V projection: the low-rank latent plus a single shared RoPE key
        kv = self.kv_a_proj(hidden_states)
        k_pass, k_rot = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass))
        k_pass = k_pass.view(
            batch_size, seq_length, self.num_attention_heads, self.qk_nope_head_dim + self.v_head_dim
        ).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Apply RoPE to the decoupled dims; the RoPE key is shared across heads, so broadcast then expand
        cos, sin = position_embeddings
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot.unsqueeze(1), cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
        # Concatenate the no-RoPE and RoPE parts of Q/K
        query_states = torch.cat([q_pass, q_rot], dim=-1)
        key_states = torch.cat([k_pass, k_rot], dim=-1)
        # Attention (reuses eager_attention_forward from the GQA section above)
        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
        )
        # Output projection
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
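
To see why caching the latent pays off, here is a small sketch comparing what the cache would hold per token and per layer if the per-head K/V were stored, versus what this module actually needs (the latent plus the shared RoPE key). The numbers are illustrative, roughly in the range of DeepSeek-V3's published configuration; treat them as assumptions rather than exact values:

# Illustrative numbers (assumed, roughly DeepSeek-V3-scale)
num_attention_heads = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
kv_lora_rank = 512

# Caching the materialized per-head K and V (MHA-style caching)
cache_full = num_attention_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)   # 40960 values

# Caching only what is needed to rebuild K/V: the latent plus the shared RoPE key
cache_mla = kv_lora_rank + qk_rope_head_dim                                               # 576 values

print(cache_full, cache_mla, round(cache_full / cache_mla, 1))   # 40960 576 71.1

Together with the matrix-absorption identity shown earlier, this is what lets inference attend against the cached latent directly instead of re-materializing every head's K and V.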
