当前位置: 首页 > news >正文

LLM模型kv cache的估计和应用

transformer模型加速生成式推理的常用策略是KV cache。具体为:

1)预填充阶段,依据prompt为每个transformer层生成key cache和value cache,即kv cache;

2)解码阶段,使用并更新kv cache,依次生成输出,当前生成的token依赖之前生成的token。

这里通过示例第i个transformer的运行过程,分析为什么是kv cache,而不是kq或vq cache。然后,以此为基准,分析和估计kv cache的显存占用量。

1 transformer架构

生成式LLM的transformer架构如下图所示。左边encoder,右边decoder,采用cross attention。

decoder的multi head attention的K和V来自于encoder,而Q则来自于decoder。

KV cache计算过程设计到transformer的权重矩阵,具体如下。

第i个transformer层的权重矩阵

W_Q^i, W_K^i, W_V^i, W_O^i, W_1^i, W_2^i

其中,self-attention块的4个权重矩阵

W_Q^i, W_K^i, W_V^i, W_O^i \in \mathbb{R}^{h \times h}

mlp块的2个权重矩阵

W_1^i \in \mathbb{R}^{h \times 4h}, W_2^i \in \mathbb{R}^{4h \times h}

2 预填充阶段

预填充阶段,即直接基于prompt输入计算第i个transformer层的kv cache的过程。

假设第i个transformer层的输入为x^i,self-attention块的query,key,value和output表示为x_Q^i, x_K^i, x_V^i, x_O^i \in \mathbb{R}^{b \times s \times h}

针对xi,key cache和value cache的计算过程为

x_K^i = x_i \cdot W_K^i

x_V^i = x_i \cdot W_V^i

第i个transformer层剩余计算过程为

x_Q^i = x^i \cdot W_Q^i

x_{out}^i = softmax(\frac{x_Q^i {x_K^i}^T}{\sqrt{h}}) \cdot x_V^i W_o^i + x_i

x_{out}^{i+1} = f_{gelu}(x_{out}^i \cdot W_1) \cdot W_2 + x_{out}^i

3 解码阶段

3.1 cross attention

结合上文提到的transformer架构,在decoder阶段,key vector和value vector来自于对encoder输出的拼接可以累积,而xQ则来自于decoder的实时在线计算,这也是为什么是KV cache,而不是QKV或QV cache的原因。示例图如下。

示例decoder cross attention代码如下所示。

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.cross_attn = MultiHeadAttention(d_model, num_heads)self.ffn = FeedForward(d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(0.1)def forward(self, x, enc_output, src_mask, tgt_mask):# Self attention (маскированное)attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))# Cross attention (с выходом энкодера)# enc_output, enc_output 对应key vector和value vectorattn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(attn_output))# Feed forwardffn_output = self.ffn(x)x = self.norm3(x + self.dropout(ffn_output))return x

3.2 kv cache

给定当前生成词在第i个transformer层的向量表示为t^i \in \mathbb{R}^{b \times 1 \times h}

推断计算分两部分:更新KV cache和计算第i个transformer层的输出。

更新key cache和value cache的计算过程如下,对于xk和xv,可以通过cache 缓存之前已经计算好的结果。

x_K^{i} = Concat(x_K^i, t_i \cdot W_K^i)

x_V^{i} = Concat(x_V^i, t_i \cdot W_V^i)

第i个transformer层剩余计算过程为:

t_Q^{i} = t_i \cdot W_Q^i

t_{out}^i = softmax(\frac{t_Q^i {x_K^i}^T}{\sqrt{h}}) \cdot x_V^i \cdot W_O^i + t^i

t^{i+1} = f_{gelu}(t_{out}^i \cdot W_1) \cdot W_2 + t_{out}^i

4 KV cache估计

假设输入序列长度为s,输出序列长度为n,以float16精度保存KV cache。

参考上述分析过程,KV cache的峰值显存占用大小表示如下。

b \times (s+n) \times h \times l \times 2 \times 2 = 4blh(s+n)

第一个2表示K cache和V的cache,第二个2表示float16占2个bytes。

以GPT3为例,对比KV cache与模型参数占用显存的大小。

GPT3模型占用显存大小为350GB。

假设批次大小b=64 ,输入序列长度s=512 ,输出序列长度 n=32 。

参考上说估计公式,KV cache占用显存 4blh(s+n)=164282499072 bytes ≈164 GB,大约是模型参数的1/2。

reference

---

LLM模型的计算量估计

https://blog.csdn.net/liliang199/article/details/152081156

LLM模型的参数量估计

https://blog.csdn.net/liliang199/article/details/151839842

分析transformer模型的参数量、计算量、中间激活、KV cache

https://zhuanlan.zhihu.com/p/624740065

LLM模型的计算量与参数量的关系

https://blog.csdn.net/liliang199/article/details/152095274

LLM模型的中间激活值估计

https://blog.csdn.net/liliang199/article/details/152140815

Why KVcache, not Qcache?

https://zhuanlan.zhihu.com/p/4590995054

transformer-cross-attention

https://www.praudyog.com/deep-learning-tutorials/transformers-cross-attention/

Encoder-Decoder Cross-Attention

https://apxml.com/courses/foundations-transformers-architecture/chapter-5-encoder-decoder-stacks/encoder-decoder-cross-attention

http://www.dtcms.com/a/424564.html

相关文章:

  • 铁岭市网站建设公司重庆森林经典台词图片
  • 【开题答辩全过程】以 PHP茶叶同城配送网站的设计与实现为例,包含答辩的问题和答案
  • 自建站服务专业建网站设计公司
  • 做网站需要哪一些内容益阳建站网站制作
  • 前端开发工具都有哪些?常用前端开发工具清单与场景化推荐
  • Apifox新版本重磅升级!AI用例生成+多模块优化!
  • 深圳网站建设公司jm3q新网站备案查询
  • 智能锁网站建设关键词宜春专业的企业网站建设公司
  • 深入理解 SPI:从定义到 Spring Boot 实践
  • 麒麟区住房和城乡建设局网站桂林北站是高铁站吗
  • 彩虹表(还原函数)
  • 查表型状态机
  • 可控可信的工业界 Agent 方案研究 - parlant
  • 徐州设计网站长沙计算机培训机构排名前十
  • flink api-datastream api-sink算子
  • 有没有专门做衣服搭配的网站怎样在织梦后台里面做网站地图
  • 【go】普通map和sync.map的区别,源码解析
  • wordpress多站点详细设置(图解)建个个人网站一年多少钱
  • Python bisect
  • Docker 安装与核心知识总结
  • 编辑网站化妆品网页设计素材
  • 做视频网站的技能可以自己制作广告的软件
  • Jupyter Notebook下载安装使用教程(附安装包,图文并茂)
  • 《算法与数据结构》第七章[算法2]:广度优先搜索(BFS)
  • Salesforce 知识点:Connected App
  • 通用系统资源监控命令(Linux)
  • 衡水网站建设知识企业站系统
  • 做房产网站用什么软件亚马逊雨林的资料
  • airsim多无人机+无人车联合仿真辅导
  • 深度学习:池化(Pooling)