当前位置: 首页 > news >正文

KV Cache:大模型推理加速的核心机制

当 AI 模型生成文本时,它们经常会重复许多相同的计算,这会降低速度。KV Cache 是一种技术,它可以通过记住之前步骤中的重要信息来加快此过程。模型无需从头开始重新计算所有内容,而是重复使用已经计算过的内容,从而使文本生成更快、更高效。

从矩阵运算角度理解 KV Cache

让我们从最基础的注意力机制开始。标准的 self-attention 计算公式大家都很熟悉:

Attention(Q,K,V) = softmax(QK^T/√d_k)V

在实际应用中,随着上下文长度的增加,这个计算会变得非常昂贵。比如当我们有 10,000 个 token 时,QK^T 会产生一个 10,000×10,000 的巨大矩阵。在自回归生成过程中,每次预测新 token 都需要重新计算整个注意力矩阵。但仔细观察会发现,对于已经生成的 token,它们的 K 和 V 向量在每次计算中都是相同的。

以"我爱大模型"的生成过程为例:

  • 第1步:输入"我",预测"爱"
  • 第2步:输入"我爱",预测"大"
  • 第3步:输入"我爱大",预测"模"
  • 第4步:输入"我爱大模",预测"型"

在第4步计算时,"我爱大模"这四个字的 K 和 V 值与前面步骤中计算的完全相同。如果每次都重新计算,就是巨大的浪费。

从矩阵维度来看,每一步的计算实际上是在上一步的基础上增加一行一列。masking 机制确保了只有下三角部分参与计算,这意味着上一步计算的结果可以完全复用。

请添加图片描述

KV Cache 的核心思想是:缓存之前计算过的 K 和 V 向量,每次只计算新增 token 的部分

为什么不缓存 Q? 因为 Q 向量始终是当前步骤新生成的 token,没有复用价值。每次生成新 token 时,Q 是查询的量,即该值是基于每次的新 token 计算的。

具体来说:

  • 第1步:计算并缓存"我"的 K 和 V
  • 第2步:只计算"爱"的 K 和 V,与缓存的"我"组合使用
  • 第3步:只计算"大"的 K 和 V,与缓存的"我爱"组合使用
  • 依此类推…

这样,每一步的计算量从 O(n²) 降低到 O(n),其中 n 是当前序列长度。

让我们更仔细地看看这个过程:

第2步详细分析
当我们需要预测"大"时,传统方法会重新计算"我"和"爱"之间的所有注意力关系。但实际上:

  • "我"的 K 和 V 向量在第1步已经计算过
  • 我们只需要计算"爱"与"我"的关系,以及"爱"与自己的关系
  • 缓存的"我"的向量直接复用,避免重复计算

第3步时

  • "我"和"爱"的 K、V 向量都已缓存
  • 只需计算"大"与之前所有 token 的关系
  • 上一步形成的注意力权重矩阵的上三角部分完全不变
    在这里插入图片描述

KV Cache 关键点

首先需要明确 KV Cache 的几个核心特征:

只在推理阶段使用:训练时不需要 KV Cache,因为训练时所有 token 都是已知的。推理时由于是逐个生成 token,才需要这种缓存机制。

仅存在于 Decoder 中:如果你用的是 BERT 这种纯 Encoder 模型,是用不上 KV Cache 的。只有像 GPT 这样的自回归模型才需要。

KV Cache 内存计算

使用KV Cache即是使用空间换取时间,以下公式计算了当推理n个token(序列长度)所需占用的显存空间

KV Cache 内存 = 2 × 层数 × 注意力头数 × 头维度 × 序列长度 × 数据类型字节数

说明:

  • 这里的 “2” 代表 K 和 V 两个缓存矩阵。
  • 需要乘以层数是因为:每个 Transformer block 都有自己的 KV Cache,不同层的 K 和 V 值不同。

以 Llama 3 70B 为例的详细计算

模型参数:

  • 层数:80 层
  • 注意力头数:64 头
  • 头维度:128 维
  • 数据类型:FP16 (2字节)

单个 token 的 KV Cache:

2 × 80 × 64 × 128 × 1 × 2 = 2,621,440 字节 ≈ 2.5MB

单次请求假设是中等文本长度 1K tokens,推理一次需要占用的内存是:

2.5MB × 1,000 = 2,500MB ≈ 2.5GB

20个并发用户,每人 1K tokens,这时候我们需要*20,因为每个用户的 KV Cache 是独立的,无法共享:

2.5GB × 20 = 50GB

由此可以看出公司上线一个大模型,与序列长度和用户数量线性相关,这部分也是一个很大的资源消耗。因此当我们下载一个大模型时,最好别下满,需要预留好KV Cache的显存空间

从代码直观感受KV Cache性能

以GPT-2代码做示例:

import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizerdevice = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)for use_cache in (True, False):times = []for _ in range(10):  # 测试10次取平均start = time.time()model.generate(**tokenizer("什么是KV缓存?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)times.append(time.time() - start)print(f"{'使用' if use_cache else '不使用'} KV缓存: {round(np.mean(times), 3)} ± {round(np.std(times), 3)} 秒")

测试结果让我印象深刻:

  • 使用 KV Cache: 11.885 ± 0.272 秒
  • 不使用 KV Cache: 56.197 ± 1.855 秒

相关文章:

  • 八、【状态管理篇】:Pinia 在大型应用中的状态管理实践
  • mediapipe标注视频姿态关键点(基础版加进阶版)
  • SE91 找到报错的程序
  • MySQL的参数 innodb_force_recovery 详解
  • 研发中的隐形瓶颈:知识为何越来越难被留下?
  • 清理skywalking历史索引
  • C++:设计模式--工厂模式
  • 【MySQL】第11节|MySQL 8.0 主从复制原理分析与实战
  • 看fp脚本学习的知识1
  • vmvare 虚拟机内存不足
  • atomic.Value与sync.map有什么区?
  • Navicat 17 SQL 预览时表名异常右键表名,点击设计表->SQL预览->另存为的SQL预览时,表名都是 Untitled。
  • 02.【Qt开发】Qt Creator介绍及新建项目流程
  • 跳表(Skip List)查找算法详解
  • 豆包AI一键生成短视频脚本,内容创作更高效
  • 【git】 pull + rebase 或 pull + merge什么区别?
  • 没有经验能考OCP认证吗?
  • SOC-ESP32S3部分:16-I2C
  • Java基础 Day22
  • MySql(四)
  • 广西梧州为什么不能去/智能网站推广优化
  • 深圳个人网站建设/做引流的公司是正规的吗
  • 音乐网站后台模板/网络营销工具体系
  • 有没有专门做二手车网站/商业策划公司十大公司
  • 网站上截小屏幕 怎么做/湘潭网站建设
  • 武汉手机网站建设如何/贵阳网络推广外包