大模型中的KV Cache
1. KV Cache的定义与核心原理
KV Cache(Key-Value Cache)是一种在Transformer架构的大模型推理阶段使用的优化技术,通过缓存自注意力机制中的键(Key)和值(Value)矩阵,避免重复计算,从而显著提升推理效率。
原理:
-  自注意力机制:在Transformer中,注意力计算基于公式: 
 Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V = ∑ i = 1 n w i v i (加权求和形式) \begin{split} \text{Attention}(Q, K, V) &= \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V \\ &= \sum_{i=1}^n w_i v_i \quad \text{(加权求和形式)} \end{split} Attention(Q,K,V)=softmax(dkQK⊤)V=i=1∑nwivi(加权求和形式)
 其中,Q(Query)、K(Key)、V(Value)由输入序列线性变换得到。
-  缓存机制:在生成式任务(如文本生成)中,模型以自回归方式逐个生成token。首次推理时,计算所有输入token的K和V并缓存;后续生成时,仅需为新token计算Q,并从缓存中读取历史K和V进行注意力计算。 
-  复杂度优化:传统方法的计算复杂度为O(n²),而KV Cache将后续生成的复杂度降为O(n),避免重复计算历史token的K和V。 
2. KV Cache的核心作用
-  加速推理:通过复用缓存的K和V,减少矩阵计算量,提升生成速度。例如,某聊天机器人应用响应时间从0.5秒缩短至0.2秒。 
-  降低资源消耗:显存占用减少约30%-50%(例如移动端模型从1GB降至0.6GB),支持在资源受限设备上部署大模型。 
-  支持长文本生成:缓存机制使推理耗时不再随文本长度线性增长,可稳定处理长序列(如1024 token以上)。 
-  保持模型性能:仅优化计算流程,不影响输出质量。 
3. 技术实现与优化策略
实现方式:
-  数据结构 - KV Cache以张量形式存储,Key Cache和Value Cache的形状分别为(batch_size, num_heads, seq_len, k_dim)和(batch_size, num_heads, seq_len, v_dim)。
 
- KV Cache以张量形式存储,Key Cache和Value Cache的形状分别为
-  两阶段推理: - 初始化阶段:计算初始输入的所有K和V,存入缓存。
- 迭代阶段:仅计算新token的Q,结合缓存中的K和V生成输出,并更新缓存。
 • 代码示例(Hugging Face Transformers):设置model.generate(use_cache=True)即可启用KV Cache。
 
优化策略:
-  稀疏化(Sparse):仅缓存部分重要K和V,减少显存占用。 
-  量化(Quantization):将K和V矩阵从FP32转为INT8/INT4,降低存储需求。 
共享机制(MQA/GQA):
-  Multi-Query Attention (MQA):所有注意力头共享同一组K和V,显存占用降低至1/头数。 
-  Grouped-Query Attention (GQA):将头分组,组内共享K和V,平衡性能和显存。 
4. 挑战与局限性
-  显存压力:随着序列长度增加,缓存占用显存线性增长(如1024 token占用约1GB显存),可能引发OOM(内存溢出)。 
-  冷启动问题:首次推理仍需完整计算K和V,无法完全避免初始延迟。 
5、python实现
import torch
import torch.nn as nn# 超参数
d_model = 4
n_heads = 1
seq_len = 3
batch_size = 3# 初始化参数(兼容多头形式)
Wq = nn.Linear(d_model, d_model, bias=False)
Wk = nn.Linear(d_model, d_model, bias=False)
Wv = nn.Linear(d_model, d_model, bias=False)# 生成模拟输入(整个序列一次性输入)
input_sequence = torch.randn(batch_size, seq_len, d_model)  # [B, L, D]# 初始化 KV 缓存(兼容多头格式)
kv_cache = {"keys": torch.empty(batch_size, 0, n_heads, d_model // n_heads),  # [B, T, H, D/H]"values": torch.empty(batch_size, 0, n_heads, d_model // n_heads) 
}# 因果掩码预先生成(覆盖最大序列长度)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # [L, L]'''
本循环是将整句话中的token一个一个输入,并更新KV缓存;
所以无需显示的因果掩码,因为因果掩码只用于计算注意力权重时,而计算注意力权重时,KV缓存中的key和value已经包含了因果掩码的信息。'''for step in range(seq_len):# 1. 获取当前时间步的输入(整个批次)current_token = input_sequence[:, step, :]  # [B, 1, D]# 2. 计算当前时间步的 Q/K/V(保持三维结构)q = Wq(current_token)  # [B, 1, D]k = Wk(current_token)  # [B, 1, D]v = Wv(current_token)  # [B, 1, D]# 3. 调整维度以兼容多头格式(关键修改点)def reshape_for_multihead(x):return x.view(batch_size, 1, n_heads, d_model // n_heads).transpose(1, 2)  # [B, H, 1, D/H]# 4. 更新 KV 缓存(增加时间步维度)kv_cache["keys"] = torch.cat([kv_cache["keys"], reshape_for_multihead(k).transpose(1, 2)  # [B, T+1, H, D/H]], dim=1)kv_cache["values"] = torch.cat([kv_cache["values"],reshape_for_multihead(v).transpose(1, 2)  # [B, T+1, H, D/H]], dim=1)# 5. 多头注意力计算(支持批量处理)q_multi = reshape_for_multihead(q)  # [B, H, 1, D/H]k_multi = kv_cache["keys"].transpose(1, 2)  # [B, H, T+1, D/H]print("q_multi shape:", q_multi.shape)print("k_multi shape:", k_multi.shape)# 6. 计算注意力分数(带因果掩码)attn_scores = torch.matmul(q_multi, k_multi.transpose(-2, -1)) / (d_model ** 0.5)print("attn_scores shape:", attn_scores.shape)# attn_scores = attn_scores.masked_fill(causal_mask[:step+1, :step+1], float('-inf'))# print("attn_scores shape:", attn_scores.shape)# 7. 注意力权重计算attn_weights = torch.softmax(attn_scores, dim=-1)  # [B, H, 1, T+1]# 8. 加权求和output = torch.matmul(attn_weights, kv_cache["values"].transpose(1, 2))  # [B, H, 1, D/H]# 9. 合并多头输出output = output.contiguous().view(batch_size, 1, d_model)  # [B, 1, D]print(f"Step {step} 输出:", output.shape)
