当前位置: 首页 > news >正文

大模型算法面试笔记——多头潜在注意力(MLA)

注意力机制解决的问题:传统序列处理模型如RNN和LSTM,捕捉长距离依赖关系的难题。注意力机制允许模型在序列的不同位置之间建立直接联系,有效捕捉远距离依赖关系。

为了减少推理过程中KV Cache占用的显存,GQA和MQA通过head之间共享KV实现,这是一种牺牲性能对存储空间妥协的方案,而MLA通过对KV对做低秩联合压缩来减少推理中的KV缓存,目标是减少kv cache存储量的同时,保存模型的效果。

具体做法是,对于每个token,先通过一个低秩矩阵将KV联合压缩到一个低维向量cKVc^{KV}cKV中,然后通过两个升维矩阵WUKW^{UK}WUKWUVW^{UV}WUV解压缩回高维,后续进行普通的多头注意力计算,这样每次只需要存这个低维向量。

这样做有个问题,就是压缩和解压操作使计算量增加了,而实际计算中,通过“矩阵吸收”操作,也就是矩阵运算过程中的结合律使多个矩阵合并,从而减少计算量。
具体计算过程如下(对qqq做相同压缩操作,于是也有了cQc^QcQWUQW^{UQ}WUQ):

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(dqkT)vWO
=softmax⁡(cQWUQ(cKVWUK)Td)cKVWUVWO=\operatorname{softmax}(\frac{c^QW^{UQ}(c^{KV}W^{UK})^T}{\sqrt{d}})c^{KV}W^{UV}W^O=softmax(dcQWUQ(cKVWUK)T)cKVWUVWO
=softmax⁡(cQ(WUQ(WUK)T)(cKV)Td)cKV(WUVWO)=\operatorname{softmax}(\frac{c^Q(W^{UQ}(W^{UK})^T)(c^{KV})^T}{\sqrt{d}})c^{KV}(W^{UV}W^O)=softmax(dcQ(WUQ(WUK)T)(cKV)T)cKV(WUVWO)
=softmax⁡(cQWUQUK(cKV)Td)cKVWUVO=\operatorname{softmax}(\frac{c^QW^{UQUK}(c^{KV})^T}{\sqrt{d}})c^{KV}W^{UVO}=softmax(dcQWUQUK(cKV)T)cKVWUVO

如上所示,计算过程中,由于矩阵乘法结合律,WUQW^{UQ}WUQWUKW^{UK}WUK合并成一个矩阵WUQUKW^{UQUK}WUQUK,同理,WUVW^{UV}WUVWOW^OWO合并成WUVOW^{UVO}WUVO

对比普通MHA计算公式:

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(dqkT)vWO
=softmax⁡(htWQKTd)vWO=\operatorname{softmax}(\frac{h_tW^QK^T}{\sqrt{d}})vW^O=softmax(dhtWQKT)vWO

可知,两种注意力机制计算量相同,没有引入额外计算量,而缓存从两个高维KKK,VVV变成了一个低维cKVc^{KV}cKV

接下来是位置编码RoPE的处理,MHA中,RoPE可以通过对qqq,kkk向量乘以一个位置相关的变换矩阵RiR_iRi(iii为当前token所处的位置)。然而,在MLA中,如果做相同的处理将会如下所示:

qiRi(kjRj)T=cQWUQRi(cjKVWUKRj)Tq_iR_i(k_jR_j)^T=c^QW^{UQ}R_i(c^{KV}_jW^{UK}R_j)^TqiRi(kjRj)T=cQWUQRi(cjKVWUKRj)T
=cQWUQRiRjT(WUK)T(cjKV)T=c^QW^{UQ}R_iR^T_j(W^{UK})^T(c_j^{KV})^T=cQWUQRiRjT(WUK)T(cjKV)T

由于RiR_iRi不是一个固定的矩阵,无法实现矩阵吸收来减少计算量。对于这个问题,deepseek的做法是将参与注意力计算的qqqvvv分成两部分,一部分进行矩阵吸收操作,不带位置信息,一部分进行位置信息计算。

对于qqq,基于潜在向量cQc^QcQ通过矩阵WQRW^{QR}WQR变换为低维向量后进行RoPE变换得到qRq^RqR;对于kkk,直接将输入hth_tht也通过一个矩阵WKRW^{KR}WKR变换后做RoPE变换得到kRk^RkR,其中kRk^RkR按照MQA的处理方式,各个head之间共享,既减少了显存调用又保证了位置编码的全局一致。然后将qRq_RqR,kRk^RkR拼接到前面计算得到的qqq,kkk向量后面,得到最终用于计算注意力的q=[qC;qR]q=[q^C;q^R]q=[qC;qR],k=[kC;kR]k=[k^C;k^R]k=[kC;kR]
这样计算点积时如下,其中ttt,jjj表示token,iii表示head:

qt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktRq_{t,i}k_{j,i}^T=[q_{t,i}^C;q_{t,i}^R]\times[k_{j,i}^C;k_t^R]=q^C_{t,i}k_{j,i}^C+q_{t,i}^Rk_t^Rqt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktR

这样不包含位置编码的部分qt,iCkj,iCq^C_{t,i}k_{j,i}^Cqt,iCkj,iC就可以进行矩阵吸收的处理,每个head缓存一个ct,iKVc_{t,i}^{KV}ct,iKV;后一项按MQA的方式计算,所有head只需缓存一个共享的ktRk^R_tktR

学习参考资料:zhihu-冷面爸

http://www.dtcms.com/a/545529.html

相关文章:

  • 常州城投建设工程招标有限公司网站泰安网站建设策划方案
  • 南通公司网站建设湖南网站优化公司
  • 做图标的网站广州海珠发布
  • 2022/12 JLPT听力原文 问题四
  • openEuler安装mysql8,流程详细
  • 【Linux】库制作与原理 从生成使用到 ELF 文件与链接原理解析
  • 【开题答辩全过程】以 儿童疫苗接种提醒系统的设计与实现为例,包含答辩的问题和答案
  • 【linux】基础开发工具(2)vim
  • 宁波找网站建设企业如何使用网络营销策略
  • 关于进一步做好网络安全等级保护有关工作的问题释疑-【二级以上系统重新备案】、【备案证明有效期三年】
  • Flink Keyed State 详解之三
  • LangChain4j学习3:模型参数
  • 驻马店做网站哪家好常州微网站建设
  • 深圳网站建设报价网站开发客户来源
  • 仓颉开发鸿蒙应用:深入理解组件生命周期的设计哲学与实践
  • Java 启动脚本-简介版
  • CFX Manager下载安装教程
  • 基于STM32HAL库判断传感器数据和系统定时器外部中断
  • 仓颉语言中的成员变量与方法:深入剖析与工程实践
  • JavaScript是如何执行的——V8引擎的执行
  • GEO:AI 时代流量新入口,四川嗨它科技如何树立行业标杆? (2025年10月最新版)
  • 【牛客刷题-剑指Offer】BM24 二叉树的中序遍历:左根右的奇妙之旅(递归+迭代双解法详解)
  • 宝山网站建设哪家好平面设计免费模板网站
  • 腾讯云 怎样建设网站免费自助建站工具
  • elasticsearch中文分词器插件下载
  • 【开题答辩全过程】以 叮叮网上图书销售管理系统为例,包含答辩的问题和答案
  • 2025—2028年教育部面47项白名单赛事汇总表(正式版)
  • IPython.display 显示网页
  • Excel怎么根据身份证号码来计算年龄?
  • 江阴网站网站建设免费的舆情网站