当前位置: 首页 > news >正文

KV Cache原理详解 + 代码理解

基本背景:

        KV Cache(Key-Value缓存)主要用于 加速自回归模型(如Transformer)的序列生成,解决以下核心问题:

  • 重复计算:传统自回归生成时,每次预测新token都需要重新计算所有历史token的Key和Value,计算成本随序列长度平方级增长(O(n²))。

  • 内存瓶颈:长序列生成时,反复投影历史token的特征矩阵会占用大量显存带宽。

        KV Cache通过缓存历史token的中间计算结果,将复杂度降至 O(n),成为GPT、LLaMA等大模型生成文本/语音的核心优化技术。

原理讲解:

        自注意力计算的公式为:

  • Q (Query):代表当前需要计算的位置(即新生成的token),每次解码时唯一变化的部分。

  • K (Key)/V (Value):代表历史token的上下文信息,需要被重复利用。

       基于这个特性,我们可以考虑缓存K、V而避免重复计算增加效率。

        由于 Decoder 中一般会有掩码矩阵,因此Q往往是个下三角矩阵,QK^{T}计算公式如下:

        

        可以看到,结果矩阵的第 k 行只用到了矩阵 X  的 第 k  个行向量。所以 X 不需要进行全部的矩阵乘法,每一步只取第 k 个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。
        在没有 KV Cache 的情况下,如果要计算第 m+1 行,需要重新计算前 m 行,但是显然这样会造成大量的重复运算,因此我们可以保存前 m 行的结果,而只计算第 m+1 行即可。

        例如

        在计算Att2时已经保存了Q1、Q2、V1、V2,这样在计算Att3时就可以直接使用而无需充型计算

        

代码实现

def decode_next_token(self,x: torch.Tensor,k_cache: torch.Tensor,v_cache: torch.Tensor,attn_mask: torch.Tensor = None,torch_sdpa: bool = True,):# Q、K、V计算q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)# KV cache拼接k_cache = torch.cat([k_cache, k], dim=1)v_cache = torch.cat([v_cache, v], dim=1)batch_size = q.shape[0]q_len = q.shape[1]kv_len = k_cache.shape[1]# Q、K、V准备q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)# 注意力计算if torch_sdpa:attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)else:attn = scaled_dot_product_attention(q, k, v, attn_mask)attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)attn = F.linear(attn, self.out_w, self.out_b)x = x + attnx = F.layer_norm(x,[self.hidden_dim],self.norm_w1,self.norm_b1,self.norm_eps1,)x = x + self.mlp.forward(x)x = F.layer_norm(x,[self.hidden_dim],self.norm_w2,self.norm_b2,self.norm_eps2,)return x, k_cache, v_cache

参考文章:

        https://blog.csdn.net/weixin_43799388/article/details/142164166

http://www.dtcms.com/a/272555.html

相关文章:

  • 从零实现一个GPT 【React + Express】--- 【2】实现对话流和停止生成
  • Pytest之收集用例规则与运行指定用例
  • 外贸网站模板 网页设计模板网站
  • WinUI3入门17:本地文件存储LocalApplicationData在哪里
  • 【佳易王桌球棋牌计时计费软件】:从功能到实操的全方位解析,灯控器适配、会员管理多场景,软件程序操作教程详解
  • BatchNorm解决梯度消失/爆炸
  • van-tabs 自定义
  • 08-自然壁纸实战教程-视频列表-云
  • softmax公式推导
  • 深度学习中的批处理vs小批量训练
  • 大数据时代UI前端的智能化升级:基于机器学习的用户意图预测
  • MyBatis-Plus的LambdaQuery用法
  • 【音视频】HTTP协议介绍
  • 钉钉拿飞书当靶
  • 测试开发和后端开发到底怎么选?
  • 打破技术债困境:从“保持现状”到成为变革的推动者
  • VILA-M3: Enhancing Vision-Language Models with Medical Expert Knowledge
  • AI大模型平台
  • 【网络】Linux 内核优化实战 - net.ipv4.tcp_keepalive_time
  • 在虚拟机中安装Linux系统
  • EasyCVR视频汇聚平台国标接入设备TCP主动播放失败排查指南
  • 操作系统-IO多路复用
  • 深度学习核心:从基础到前沿的全面解析
  • 约束-1-约束
  • 【论文笔记】A Deep Reinforcement Learning Based Real-Time Solution Policy for the TSP
  • leetcode 226 翻转二叉树
  • openEuler 24.03 (LTS-SP1) 下安装 K8s 集群 + KubeSphere 遇到 etcd 报错的解决方案
  • Qt:按像素切割图片
  • 制胶学习分享
  • FFmpeg在Go、Python、C++、Rust实践案例