KV Cache原理详解 + 代码理解
基本背景:
KV Cache(Key-Value缓存)主要用于 加速自回归模型(如Transformer)的序列生成,解决以下核心问题:
-
重复计算:传统自回归生成时,每次预测新token都需要重新计算所有历史token的Key和Value,计算成本随序列长度平方级增长(O(n²))。
-
内存瓶颈:长序列生成时,反复投影历史token的特征矩阵会占用大量显存带宽。
KV Cache通过缓存历史token的中间计算结果,将复杂度降至 O(n),成为GPT、LLaMA等大模型生成文本/语音的核心优化技术。
原理讲解:
自注意力计算的公式为:
-
Q (Query):代表当前需要计算的位置(即新生成的token),每次解码时唯一变化的部分。
-
K (Key)/V (Value):代表历史token的上下文信息,需要被重复利用。
基于这个特性,我们可以考虑缓存K、V而避免重复计算增加效率。
由于 Decoder 中一般会有掩码矩阵,因此Q往往是个下三角矩阵,计算公式如下:
可以看到,结果矩阵的第 k 行只用到了矩阵 X 的 第 k 个行向量。所以 X 不需要进行全部的矩阵乘法,每一步只取第 k 个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。
在没有 KV Cache 的情况下,如果要计算第 m+1 行,需要重新计算前 m 行,但是显然这样会造成大量的重复运算,因此我们可以保存前 m 行的结果,而只计算第 m+1 行即可。
例如
在计算Att2时已经保存了Q1、Q2、V1、V2,这样在计算Att3时就可以直接使用而无需充型计算
代码实现
def decode_next_token(self,x: torch.Tensor,k_cache: torch.Tensor,v_cache: torch.Tensor,attn_mask: torch.Tensor = None,torch_sdpa: bool = True,):# Q、K、V计算q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)# KV cache拼接k_cache = torch.cat([k_cache, k], dim=1)v_cache = torch.cat([v_cache, v], dim=1)batch_size = q.shape[0]q_len = q.shape[1]kv_len = k_cache.shape[1]# Q、K、V准备q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)# 注意力计算if torch_sdpa:attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)else:attn = scaled_dot_product_attention(q, k, v, attn_mask)attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)attn = F.linear(attn, self.out_w, self.out_b)x = x + attnx = F.layer_norm(x,[self.hidden_dim],self.norm_w1,self.norm_b1,self.norm_eps1,)x = x + self.mlp.forward(x)x = F.layer_norm(x,[self.hidden_dim],self.norm_w2,self.norm_b2,self.norm_eps2,)return x, k_cache, v_cache
参考文章:
https://blog.csdn.net/weixin_43799388/article/details/142164166