当前位置: 首页 > news >正文

【LLM】讲清楚MLA原理

需要你对MHA、MQA、GQA有足够了解,相信本文能帮助你对MLA有新的认识。

本文内容都来自https://www.youtube.com/watch?v=0VLAoVGf_74,如果阅读本文出现问题,建议直接去看一遍。

        按照Deepseek设定一些参数值:输入token长度n=10,注意力头数目n_h=128,每个注意力头的隐含层维度d_h=128,transformer block层数 l =61,使用fp16存储参数。

        先来看MHA的kv-cache计算:

cachesize = 2*n*n_h*d_h*l*2 = 2*10*128*128*61*2 = 40MB

(第一个2是因为要保存K和V,第二个2是因为fp16占2bit)

        MQA和GQA的思路是通过不同注意力头之间共享参数,减少注意力头数目n_h来达到降低开销的目的。

        这样的问题是参数的共享会导致模型效果下降,毕竟原本有128个头,128份KV参数,每份KV参数都会计算出不一样的注意力分布,让模型能更好的根据所有的注意力分布去预测下一个词,而现在128份参数变成了1份,预测效果下降是必然的。

        如何解决这个问题?如何只保留1份参数,但又能计算出128个不同的注意力分布呢?

        MLA给出的答案是,只保留原本128分参数中共有的部分,而每份参数独有的部分则提取出来,不进行保存。

        这里就碰到了MLA第一个比较难理解的点,就是怎么找出128个W_K的共有部分和独有部分?(只以K为例,V也是一样的) 

        答案是不用去找,而是从一开始就用两个矩阵,分别去学习共有部分和独有部分。也就是下图中的W_DKV和W_UK,其中W_DKV学习共有部分,W_UK学习独有部分。也就是说128个注意力头,会共用W_DKV,但是每个注意力头的W_UK是独有的,这样保证了128个注意力头能计算出128个不同的注意力分布。

这里就会碰到MLA第二个比较难理解的点,为什么最后kv-cache只用保存L_KV,而不用保存K和V?

答案是根本就不存在K和V,MLA很巧妙的利用矩阵乘法,把W_UK与W_Q融合,把W_UV和W_O融合。至于为什么能这样做,可以从公式中找出答案。

说不存在W_UK和W_UV其实并不严谨,但是这样可以更方便去理解,其实这里所谓的把W_UK与W_Q融合是指输入先经过W_Q,紧跟着就经过W_UK,从结果上来看,跟先把W_UK与W_Q相乘得到W_QUK,然后输入经过W_QUK的效果是一样的。

        原本,加入W_DKV后,注意力的计算公式为:

 A = QK^T = (XW_Q)(XW_{DKV}W_{UK})^T

        按照矩阵运算,上述公式可以写成下述形式:

A = (XW_Q)(W_{UK}^TW_{DKV}^TX^T) = (XW_QW_{UK}^T)(W_{DKV}^TX^T) = (XW_QW_{UK}^T)(XW_{DKV})^T = (XW_{QUK})L_{KV}^T

        我们完全可以将W_QW_{UK}^T视作一个矩阵W_{QUK},它和W_Q并没有什么本质区别,只是维度需要调整(当然实际实现上还是两个矩阵,分开来学习)。从上式中,我们发现注意力计算公式中的K消失了

        然后是最终输出O的计算:

O = AVW_O = A(XW_{DKV}W_V)W_O = A(XW_{DKV})(W_VW_O) = AL_{KV}W_{VO}

        同理,这样就能把W_V融进W_O中,我们能够发现,最终输出的计算公式中,V也消失了

        最后的效果如下图,我们需要保存的只有L_KV,它是128个注意力头共用的,所以只需要保存一份,存储开销计算如下,整个计算公式中完全不需要考虑注意力头数目:

cachesize = n*d_h'*l*2 = 10*576*61*2 = 0.7MB

        开销降低40/0.7,约57倍,也就是deepseek技术报告中公布的压缩倍数。 

http://www.dtcms.com/a/309289.html

相关文章:

  • Linux(15)——进程间通信
  • EasyExcel 公式计算大全
  • Spring Boot Actuator 保姆级教程
  • 包裹移动识别误报率↓76%:陌讯时序建模算法实战解析
  • C#实现左侧折叠导航菜单
  • 数据结构(9)栈和队列
  • 完整的 Spring Boot + Hibernate/JPA + P6Spy 配置指南
  • 凸优化:常见的优化问题,偏统计视角
  • cesium FBO(四)自定义相机渲染到Canvas(离屏渲染)
  • android APT技术
  • 今日链表系列
  • 京东零售在智能供应链领域的前沿探索与技术实践
  • X2Doris是SelectDB可视化数据迁移工具,安装与部署使用手册,轻松进行大数据迁移
  • Blender 智能模型库 | 人物·建筑·场景·机械等 近万高精度模型
  • 无人机自动跟随模块技术分析
  • SpringMVC的高级特性
  • 机密计算与AI融合:安全与智能的共生架构
  • 《B3611 【模板】传递闭包》
  • 编程与数学 03-002 计算机网络 17_云计算与网络
  • Java 日期时间处理:分类、用途与性能分析
  • macOS卸载.net core 8.0
  • HarmonyOS】鸿蒙应用开发中常用的三方库介绍和使用示例
  • 代码随想录算法训练营第三十八天
  • NLP 和 LLM 区别、对比 和关系
  • MT Photos图库部署详解:Docker搭建+贝锐蒲公英异地组网远程访问
  • 卸油作业安全设施漏检率↓76%!陌讯多模态融合算法实战解析
  • [AI8051U入门第十二步]W5500-Modbus TCP从机
  • 浏览器【详解】内置Observer(共五种,用于前端监控、图片懒加载、无限滚动、响应式布局、生成安全报告等)
  • 算法26. 删除有序数组中的重复项
  • 宝塔网站如何禁止使用IP访问