kvcache比赛记录
一些简单的记录
文章目录
- test.py注释版
- 文档处理
- 预计算独立的KVCache
- 目标:模拟“预埋”的KVCache
- 例子 walkthrough
- 为什么这个KVCache是“有瑕疵”的?
- 开始推理
- 核心原理:`llm.generate` 触发的“幕后”连锁反应
- 例子 Walkthrough:深入模型内部
- **进入第 0 层 (`layer_ind = 0`)**
- **进入第 1 层 (`layer_ind = 1`)**
- choose_recompute函数分析
- 1\. 原理讲解:核心思想是什么?
- 2\. 功能讲解:代码是如何工作的?
- 3\. 举例说明:模拟执行过程
所有的精度指标(F1, Precision, Recall)都是通过将 Model Answer 与这个 Dataset Answer进行对比计算出来的
test.py注释版
# 导入所有必要的库
from vllm import LLM, SamplingParams # vLLM的核心库,用于加载模型和进行推理
from utils import load_dataset, build_qa_prompt, scorer_all, extract_after_think # 从utils.py导入辅助函数
from transformers import AutoTokenizer # Hugging Face的库,用于加载分词器
import torch # PyTorch库,用于张量操作
import os # 用于处理文件路径
import sys # 系统库
import numpy as np # NumPy库,用于数值计算,如此处的平均值
import json # 用于处理JSON数据格式# --- 1. 环境与路径配置 ---
# 原理讲解:
# 这部分代码设置了所有必要的路径和参数,是脚本运行的基础。
# 您需要根据自己的环境修改这些路径,特别是`base_dir`和`model_path`。base_dir = '/home/mzq/massive_storage' # 设置一个基础目录,用于存放最终生成的JSON结果文件
num_runs = 1 # 设置要运行的数据集样本数量,这里设为1表示只跑数据集中的第一个问题
tp = 1 # Tensor Parallelism size,张量并行大小,设为1表示使用单GPU
max_tokens = 2048 # 设置模型生成的最大token数
dataset_name = 'just_for_test' # 指定要使用的数据集文件名(不含.json后缀)
model = "DeepSeek-R1-Distill-Qwen-14B" # 指定要加载的模型名称
model_path = f"/data/mzq/massive_storage/{model}" # 拼接成完整的模型存放路径# --- 2. 数据与模型加载 ---
# 原理讲解:
# 这部分代码负责将模型、分词器和数据集加载到内存中。
# 同时,它还定义了用于构建完整请求的prompt模板。# 拼接出最终结果JSON文件的完整路径
json_path = os.path.join(base_dir, f'{model}-{dataset_name}-test.json')
# 加载评测数据集
eval_dataset = load_dataset(f"/home/mzq/massive_storage/vllm-ascend-dev/data/{dataset_name}.json")# 定义用于构建prompt的模板字符串
if dataset_name == 'just_for_test':# 这是系统提示(System Prompt),告诉模型它的角色和任务规则prefix_prompt = "你是一个有帮助且知识渊博的助手。你将得到一个问题和一组从知识库中检索到的文档。请仅使用提供的上下文信息来回答问题。如果上下文中没有足够的信息来回答问题,请如实地说明。需遵从下面的指令:1、你将得到一个用户问题和一组检索到的文档。2、仅使用提供的上下文来回答问题。3、如果问题无法在上下文中找到答案,请回答:“上下文没有提供足够的信息来回答这个问题。”4、简洁并事实性回答。\n文章:\n"# 这是查询提示,放在所有文档之后,引出用户的问题query_prompt = "请基于上述文章回答下面的问题。\n问题:"# 初始化vLLM引擎
llm = LLM(model=model_path, # 模型路径max_model_len=20000, # 模型能处理的最大序列长度tensor_parallel_size=tp, # 张量并行大小enforce_eager=True, # 强制使用Eager模式,便于调试和获取中间状态enable_chunked_prefill=False, # 禁用分块预填充dtype='bfloat16') # 使用bfloat16数据类型以节省显存# 加载模型对应的分词器
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 加载模型的配置文件(config.json),以获取模型层数等信息
model_config = load_dataset(f'{model_path}/config.json')
num_layer = model_config["num_hidden_layers"] # 获取模型的总层数# --- 3. 主循环与评测逻辑 ---
# 原理讲解:
# 这是脚本的核心。它遍历数据集中的每一个问题,并执行两个关键的推理步骤:
# 1. 预计算KVCache:模拟“预埋”操作,获取每个文档块独立的、无交叉注意力的KVCache。
# 2. 正式推理:使用选择性重算策略,对完整的长文本进行推理,并记录性能和精度。# 初始化用于存储所有问题结果的列表
ttft_blend = [] # 存储每个问题的TTFT
answers = [] # 存储标准答案
result_w_caches = [] # 存储模型答案
t_df1 = [] # 存储每个问题的F1分数
t_dpr = [] # 存储每个问题的Precision
t_drecall = [] # 存储每个问题的Recall# 打开结果文件,准备以追加模式('a')写入
with open(json_path, mode='a', newline='', encoding='utf-8') as file:file.write('[\n') # 手动写入JSON数组的开头# 遍历数据集中的每一个样本(问题)for ii, ex in enumerate(eval_dataset):dict_obj = {} # 用于存储当前问题结果的字典if ii == num_runs: # 如果处理的问题数量达到了设定的num_runs,则跳出循环breakanswer = ex["answers"] # 获取标准答案列表question = ex["question"] # 获取问题字符串# 使用utils.py中的函数构建prompt列表doc_prompts, q_prompt = build_qa_prompt(ex, query_prompt)# 组合成最终的文档块列表,结构为:[头部prompt, 文档1, 文档2, ..., 尾部prompt]doc_list = [prefix_prompt] + doc_prompts + [q_prompt]# --- 3.1 预计算独立的KVCache ---# 原理讲解:# 这是为了模拟赛题中“预埋”的、缺少交叉注意力的KVCache。# 脚本对每个文档块独立进行一次推理,然后通过`hack_kv`这个“后门”变量,# “偷”出该次推理产生的KVCache。最后将所有块的KVCache拼接起来,# 形成一个完整的、但有瑕疵的`old_kvs`,作为后续“复用”的来源。sampling_params = SamplingParams(temperature=0, max_tokens=1) # 设置采样参数,这里只需要生成1个token来触发prefill# 获取模型内部用于传递元数据的字典recompute_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.recompute_metadata# 计算每个文档块的token长度,用于后续计算重算率doc_length = [len(tokenizer.encode(prefix_prompt))]for doc in doc_prompts:doc_length.append(len(tokenizer.encode(doc)) - 1)doc_length.append(len(tokenizer.encode(q_prompt)) - 1)recompute_metadata["doc_length"] = doc_lengthrecompute_metadata["kv_done"] = False # 标记:此时“旧”的KVCache还没准备好chunk_past_key_values = [] # 用于存储拼接好的“旧”KVCache# 遍历每一个文档块,独立计算其KVCachefor i in range(len(doc_list)):prompts = doc_list[i]llm.generate(prompts, sampling_params) # 对当前块进行推理# 直接访问模型底层,获取所有注意力层llm_layers = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.layers# 遍历模型的每一层for j in range(num_layer):# 从我们修改的qwen2.py中,取出“偷”出来的K,V (`hack_kv`)past_key_values = llm_layers[j].self_attn.hack_kvif i == 0: # 如果是第一个块(头部prompt)temp_k = past_key_values[0][:].clone()temp_v = past_key_values[1][:].clone()chunk_past_key_values.append([temp_k, temp_v]) # 初始化列表else: # 如果是后续的块# 注意:[1:]是为了去掉每个块自带的BOS token,避免重复temp_k = past_key_values[0][1:].clone()temp_v = past_key_values[1][1:].clone()# 将当前块的K,V拼接到之前所有块的K,V之后chunk_past_key_values[j][0] = torch.cat((chunk_past_key_values[j][0], temp_k), dim=0)chunk_past_key_values[j][1] = torch.cat((chunk_past_key_values[j][1], temp_v), dim=0)# 将拼接好的、有瑕疵的KVCache存入模型,供后续使用llm.llm_engine.model_executor.driver_worker.model_runner.model.model.old_kvs = chunk_past_key_values# --- 3.2 执行正式的、带选择性重算的推理 ---# 原理讲解:# 现在,我们将所有文档拼接成一个完整的长文本进行推理。# 因为`kv_done`被设为True,我们修改过的`qwen2.py`会启动选择性重算逻辑,# 它会调用你在`functions.py`中写的算法来决定哪些token复用`old_kvs`,哪些重新计算。recompute_metadata["kv_done"] = True # 标记:“旧”的KVCache已准备就绪,可以开始选择性重算了# 将所有文档块拼接成一个长字符串prompts = ''for doc in doc_list:prompts += doc# 设置正式推理的采样参数sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, ignore_eos=False)# 执行生成任务,这将触发我们的选择性重算逻辑output = llm.generate(prompts, sampling_params)# --- 3.3 计算并验证重算率 ---# 原理讲解:# 推理结束后,从模型元数据中取出`valid_list`(由`qwen2.py`在每层记录的重算token数),# 计算出实际重算token总数,再除以理论最大重算token总数,得到重算率。valid_list = recompute_metadata["valid_list"]recompute_num = 0for u in valid_list:recompute_num += u # 累加每层的重算数量# 计算理论最大重算数 = (文档总token数) * (模型层数)total_num = np.sum(doc_length[1:-1]) * num_layerrecompute_ratio = recompute_num / total_num # 计算重算率if recompute_ratio >= 0.3:print(f"Error!recompute ratio {recompute_ratio} is too large!")else:print("这次回答重算率没超过0.3") # wgh 自己加的输出# --- 3.4 计算性能与精度指标 ---# 原理讲解:# 从vLLM的输出中提取TTFT。对于精度,使用模型的回答与数据集中每一个标准答案进行比较,# 取F1分数最高的那一次作为最终结果。res = output[0].outputs[0].text # 获取模型最终的文本输出print(f"问题: {question}")print(f"模型答案: {res}")print(f"正确答案: {answer}")# 计算TTFTttft = output[0].metrics.first_token_time - output[0].metrics.first_scheduled_timeprint(f"TTFT with cache: {ttft}")ttft_blend.append(ttft) # 存入列表# 初始化当前问题的最高精度分数temp_df1 = 0temp_dpr = 0temp_drecall = 0# 遍历所有可能的标准答案for j in range(len(answer)):# 调用评分函数,注意`extract_after_think`会去掉模型的思考过程df1, dpr, drecall = scorer_all('dureader_all', extract_after_think(res), str(answer[j]))# 保留分数最高的结果if df1 > temp_df1:temp_df1 = df1if dpr > temp_dpr:temp_dpr = dprif drecall > temp_drecall:temp_drecall = drecalldf1 = temp_df1dpr = temp_dprdrecall = temp_drecall# 将当前问题的最终分数存入列表t_df1.append(df1)t_dpr.append(dpr)t_drecall.append(drecall)# --- 3.5 保存结果到JSON文件 ---dict_obj["id"] = iidict_obj["Query"] = questiondict_obj["Model Answer"] = resdict_obj["Dateset Answer"] = answerdict_obj["F1 with Dataset"] = df1dict_obj["Precision with Dataset"] = dprdict_obj["Recall with Dataset"] = drecalldict_obj["TTFT"] = ttftjson.dump(dict_obj, file, indent=4, ensure_ascii=False) # 将字典写入JSON文件file.write(',\n') # 手动写入逗号和换行,为下一个对象做准备# --- 4. 计算并输出最终平均结果 ---# 循环结束后,计算所有问题指标的平均值res_obj = {}res_obj["avg_ttft"] = np.mean(ttft_blend)res_obj["avg_f1 with Dataset Answer"] = np.mean(t_df1)res_obj["avg_precision with Dataset Answer"] = np.mean(t_dpr)res_obj["avg_recall with Dataset Answer"] = np.mean(t_drecall)json.dump(res_obj, file, indent=4, ensure_ascii=False) # 写入平均结果file.write(']\n') # 手动写入JSON数组的结尾# 打印最终的平均结果到控制台
print(f"f1: {np.mean(t_df1)}, precision: {np.mean(t_dpr)}, recall: {np.mean(t_drecall)}, ttft: {np.mean(ttft_blend)}")
文档处理
我们先分析
answer = ex["answers"] # 获取标准答案列表question = ex["question"] # 获取问题字符串# 使用utils.py中的函数构建prompt列表doc_prompts, q_prompt = build_qa_prompt(ex, query_prompt)# 组合成最终的文档块列表,结构为:[头部prompt, 文档1, 文档2, ..., 尾部prompt]doc_list = [prefix_prompt] + doc_prompts + [q_prompt]
我们这里先预处理一下文档
doc_list 大致结构如下
# 开头
你是一个有帮助且知识渊博的助手。你将得到一个问题和一组从知识库中检索到的文档。请仅使用提供的上下文信息来回答问题。如果上下文中没有足够的信息来回答问题,请如实地说明。需遵从下面的指令:1、你将得到一个用户问题和一组检索到的文档。2、仅使用提供的上下文来回答问题。3、如果问题无法在上下文中找到答案,请回答:“上下文没有提供足够的信息来回答这个问题。”4、简洁并事实性回答。
文章:# 中间的文档部分
<|User|>标题:为什么鼻子两侧总是红红的?_百度知道黑头的产生 黑头是硬化油脂阻塞物,通常出现在颜面的额头、鼻子等部位,当油脂腺受到过分刺激,毛孔充满多余的油脂而造成阻寒时,在鼻头及其周围部分,经常会有油腻的感觉。这些油脂最终会硬化,经氧化后成为黑色的小点,这些小点就是被称作黑头的油脂阻塞物。 错误去黑头方法:! 1、用手挤:很多人都会用手挤,但由于指甲易藏细菌,所以容易引致皮肤发炎,而且毛孔会越变越大。 2、用刷擦:这种方法只适用于去死皮,如去黑头,作用不大,若大力擦会擦损皮肤。 各路人马总结的有效的去黑头方: 一、盐加牛奶去黑头 1.最好用没有用过的食盐,可以在刚开封时用小瓶单独装起来; 2.每次用4~5滴牛奶兑盐,在盐半溶解状态下开始用来按摩; 3.由于此时的盐未完全溶解仍有颗粒,所以在按摩的时候必须非常非常小力; 4.半分钟后用清水洗去,不能太# 结尾部分请基于上述文章回答下面的问题。
问题:鼻子周围红红的
回答:<|Assistant|><think>
预计算独立的KVCache
通过一个巧妙的循环,用低成本的方式(独立计算+拼接)构建了一个不包含交叉注意力的“旧”KVCache (old_kvs)。
这个 old_kvs 就是您在 functions.py 中进行决策的基准。您的算法需要判断:对于拼接后的文本,哪些位置的token受交叉注意力影响巨大,以至于我们不能使用这个“有瑕疵”的旧KVCache,而必须重新计算它们,从而“修复”这个瑕疵
我们用一个简单的例子来把整个过程走一遍。
目标:模拟“预埋”的KVCache
想象一下,在真实世界里,我们可能会提前把很多文档(比如维基百科页面)单独处理好,存下它们的KVCache。当用户提问时,我们把相关的几个文档的KVCache拿出来,直接拼接在一起,希望能快速得到答案。
这部分代码就是在模拟这个“直接拼接”的过程。
例子 walkthrough
假设我们的输入 doc_list
简化后是这样的(实际是很长的文档块):
doc_list = ["头部prompt", # i = 0"文档A:猫喜欢鱼。", # i = 1"文档B:狗喜欢骨头。",# i = 2"尾部prompt" # i = 3
]
并且,为了简化,我们假设模型只有1层 (num_layer = 1
)。
现在,我们来逐行过一遍代码:
1. 初始化
sampling_params = SamplingParams(temperature=0, max_tokens=1)
# ...
chunk_past_key_values = [] # 准备一个空列表,用来存放我们最终拼接好的KVCache
```chunk_past_key_values` 将会是一个列表,因为模型有 `num_layer` 层,它需要为每一层都存一份KVCache。因为我们假设只有1层,所以它最后只会包含一个元素 `[[K_tensor, V_tensor]]`。**2. `for i in range(len(doc_list))` 循环开始**这个循环会执行4次,每次处理 `doc_list` 中的一个块。---
**第一次循环 (i = 0, 处理 "头部prompt")**```python
prompts = doc_list[0] # prompts = "头部prompt"
llm.generate(prompts, sampling_params) # 模型处理 "头部prompt"
# ...
past_key_values = llm_layers[0].self_attn.hack_kv # “偷”出KVCache
# 假设 KVCache_prompt = ([K_p], [V_p])if i == 0:temp_k = past_key_values[0][:].clone() # temp_k = [K_p]temp_v = past_key_values[1][:].clone() # temp_v = [V_p]chunk_past_key_values.append([temp_k, temp_v]) # 初始化列表
- 发生了什么?:模型只看到了 “头部prompt”,并计算出了它的KVCache。因为是第一次循环 (
i==0
),我们直接把这份KVCache存入chunk_past_key_values
。 - 此时
chunk_past_key_values
的状态:[ [[K_p], [V_p]] ]
第二次循环 (i = 1, 处理 “文档A:猫喜欢鱼。”)
prompts = doc_list[1] # prompts = "文档A:猫喜欢鱼。"
llm.generate(prompts, sampling_params) # 模型处理 "文档A"
# ...
past_key_values = llm_layers[0].self_attn.hack_kv
# 假设 KVCache_A = ([K_A], [V_A])# 进入 else 分支
temp_k = past_key_values[0][1:].clone() # temp_k = [K_A] (去掉BOS token)
temp_v = past_key_values[1][1:].clone() # temp_v = [V_A] (去掉BOS token)# 这是最关键的一步!
chunk_past_key_values[0][0] = torch.cat((chunk_past_key_values[0][0], temp_k), dim=0)
chunk_past_key_values[0][1] = torch.cat((chunk_past_key_values[0][1], temp_v), dim=0)
- 发生了什么?:
- 模型只看到了 “文档A:猫喜欢鱼。”,它完全不知道前面还有个 “头部prompt”。因此,它计算出的
KVCache_A
是孤立的,不包含任何关于 “头部prompt” 的信息。 torch.cat
函数的作用是拼接张量。代码把新算出来的temp_k
([K_A]) 拼接到chunk_past_key_values
中已经存在的K张量 ([K_p]) 的后面。temp_v
同理。
- 模型只看到了 “文档A:猫喜欢鱼。”,它完全不知道前面还有个 “头部prompt”。因此,它计算出的
- 此时
chunk_past_key_values
的状态:[ [[K_p, K_A], [V_p, V_A]] ]
第三次循环 (i = 2, 处理 “文档B:狗喜欢骨头。”)
这个过程和第二次完全一样。模型只看到了 “文档B”,独立地计算出 KVCache_B
,然后代码把它拼接到 chunk_past_key_values
的末尾。
- 此时
chunk_past_key_values
的状态:[ [[K_p, K_A, K_B], [V_p, V_A, V_B]] ]
循环结束
经过所有循环后,chunk_past_key_values
里存储的就是一个强行拼接起来的KVCache。
为什么这个KVCache是“有瑕疵”的?
现在对比一下我们得到的模拟KVCache和理想KVCache的区别:
-
我们得到的模拟KVCache:
K_A
是在模型只看得到"文档A" 的情况下计算出来的。K_B
是在模型只看得到"文档B" 的情况下计算出来的。K_B
的计算完全没有考虑到 “文档A” 的存在。
-
理想的KVCache (完全重算):
- 当模型一次性处理 “头部prompt 文档A 文档B” 时,在计算
K_B
的时候,模型会同时关注 “头部prompt” 和 “文档A” 的内容(这就是交叉注意力 cross-attention)。 - 因此,理想情况下的
K_B
和我们模拟出来的K_B
是完全不同的。
- 当模型一次性处理 “头部prompt 文档A 文档B” 时,在计算
开始推理
这个部分会调用我们的choose_recompute
函数
###标记缓存准备就绪recompute_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.recompute_metadatarecompute_metadata["kv_done"] = True###开始推理prompts = ''for doc in doc_list:prompts += docsampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, ignore_eos = False)output = llm.generate(prompts, sampling_params)recompute_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.recompute_metadata
我将结合一个具体的例子,为您详细讲解当 llm.generate
被调用时,底层发生了什么,以及您的 choose_recompute
函数是如何被一步步使用的。
核心原理:llm.generate
触发的“幕后”连锁反应
当 test.py
调用 output = llm.generate(prompts, sampling_params)
时,vLLM 框架会启动一个完整的、从头到尾的推理流程。因为我们修改了 qwen2.py
,这个流程就变得很特别:
test.py
(指挥官) -> vLLM框架
-> qwen2.py
(调度官) -> functions.py
(您,决策者)
这个连锁反应会在模型的每一层都发生一次。
例子 Walkthrough:深入模型内部
假设我们的模型简化为只有 2层,输入的完整 prompts
在分词后,文档部分有 10个 token。
准备工作:
recompute_metadata["kv_done"]
已经被设为True
。llm.llm_engine...old_kvs
里面已经存好了我们之前模拟的、有瑕疵的10个token的KVCache。
现在,llm.generate
开始执行…
进入第 0 层 (layer_ind = 0
)
-
vLLM的常规操作:模型接收完整的10个token的输入,正常计算出这一层全新的、包含了完整交叉注意力的
q_new
,k_new
,v_new
。 -
qwen2.py
的调度:在Qwen2Attention.forward
函数中,代码检测到kv_done
是True
,于是它暂停了常规流程,开始执行我们的特殊逻辑。 -
调用您的
choose_recompute
函数:qwen2.py
准备好所有“情报”,调用您的函数。此时传递给您函数的参数是:hidden_states
: 完整的10个token的输入隐状态。old_k
,old_v
: 我们之前准备好的、有瑕疵的KVCache。layer_ind
: 0valid_ind
:[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
(在第0层,默认所有token都是候选重算对象)。q
,k
,v
: 就是第1步里算出来的q_new
,k_new
,v_new
。
-
您的算法进行决策:
- 您的
choose_recompute
函数开始执行。假设您的算法(比如我之前给的“渐进式注意力重构”算法)在第0层的策略是:根据v_new
和old_v
的差异,保留差异最大的50%的token进行重算。 - 您的函数经过计算,返回了一个新的
valid_ind
:
[1, 0, 1, 0, 1, 1, 0, 0, 1, 0]
(5个1,5个0)
- 您的
-
qwen2.py
执行您的决策:qwen2.py
拿到了您返回的[1, 0, 1, ...]
。- 它找到所有标记为
0
的位置(第1, 3, 6, 7, 9个token)。 - 在这些位置上,它执行替换操作:用
old_k
和old_v
中对应位置的向量,覆盖掉k_new
和v_new
中对应位置的向量。 - 现在,这一层的KVCache变成了一个“混合体”:5个token用的是新鲜出炉、包含全局信息的K/V,另外5个token用的则是旧的、有瑕疵的K/V。
-
进入下一层:模型使用这个“混合KVCache”完成第0层的注意力计算,然后生成进入第1层的
hidden_states
。一个至关重要的细节:此时,只有那5个被标记为1
的token的hidden_states
才会被计算并传递到下一层。
进入第 1 层 (layer_ind = 1
)
-
vLLM的常规操作:模型接收来自上一层的、只有5个token的
hidden_states
,并为它们计算出全新的q_new
,k_new
,v_new
。 -
qwen2.py
的调度:同样,检测到kv_done
是True
,暂停流程。 -
再次调用您的
choose_recompute
函数:qwen2.py
再次准备“情报”并调用您。这次的参数是:hidden_states
: 只有5个token的隐状态。old_k
,old_v
: 还是那份完整的、有瑕疵的KVCache。layer_ind
: 1valid_ind
:[1, 0, 1, 0, 1, 1, 0, 0, 1, 0]
(这是上一层决策的结果)。q
,k
,v
: 只有5个token的q_new
,k_new
,v_new
。
-
您的算法再次决策:
- 您的
choose_recompute
函数再次执行。假设您在第1层的策略是:在上一层保留的5个token中,根据注意力熵,再淘汰掉40%,只保留最重要的3个。 - 您的函数返回了最终的
valid_ind
:
[1, 0, 0, 0, 1, 1, 0, 0, 0, 0]
(3个1,7个0)。注意,这个结果必须是上一步输入的子集。
- 您的
-
qwen2.py
再次执行决策:qwen2.py
拿到这个最终版的valid_ind
,再次进行K/V替换操作。
这个过程会贯穿模型的所有层,每一层都会调用一次您的 choose_recompute
函数,让您有机会根据当前层的信息,逐步“精炼”出那些真正需要重算的核心token。
当所有层都处理完毕后,llm.generate
才算完成了 prefill
阶段,并返回最终的 output
。
choose_recompute函数分析
这个函数是赛题官方提供的一个基础示例,它实现了 CacheBlend 论文中的一种简化思想。理解它的工作原理是您进行优化的基础。
下面我将为您详细讲解这个函数的原理和功能,并用一个具体的例子来模拟它的执行过程。
1. 原理讲解:核心思想是什么?
这个算法的核心思想非常直接,可以概括为以下三点:
- 一次性决策:它并不在模型的每一层都做决策,而是选择在一个固定的、靠前的层(这里是第二层,
layer_ind == 1
)做一次“一锤子买卖”的决策。 - 后续层沿用:一旦在第二层决定了哪些token需要重算,哪些可以复用,这个决定就会被固定下来,在后续所有更深的层(
layer_ind > 1
)中都沿用这个决定,不再改变。 - 决策标准:信息变化量:它判断一个token是否需要重算的唯一标准是:当把所有文档拼接起来后,这个token的信息表示 (
V
向量) 发生了多大的变化。如果一个token的V
向量相比于它在单个文档中时的V
向量变化巨大,就说明它受到了其他文档的强烈影响,因此必须重算才能获得正确的上下文信息。
2. 功能讲解:代码是如何工作的?
我们来逐行分析这段代码的功能。
# 设置一个固定的重算比例,这里是25%
recompute_ratio = 0.25# 获取头部和尾部prompt的长度,用于后续精确地切片出文档部分的向量
begin = doc_length[0]
end = doc_length[-1]# 获取文档部分的总token数
num_tokens = len(valid_ind)# 根据比例计算出具体要重算多少个token
topk_num = int(num_tokens * recompute_ratio)# 关键判断:只在第二层 (layer_ind == 1) 执行决策逻辑
if layer_ind == 1:if topk_num != 0:# --- 这是算法的核心计算步骤 ---# 1. v[begin:-end] 是当前层根据完整上下文算出的、全新的V向量(只取文档部分)。# 2. old_v 是我们之前预埋的、缺少交叉注意力的、有瑕疵的旧V向量。# 3. (v[...] - old_v)**2 计算两者之差的平方,得到每个元素的差异值。# 4. torch.sum(..., dim=1) 将每个token向量的所有元素差异值相加,# 得到一个代表该token总信息变化量的分数。temp_diff = torch.sum((v[begin:-end] - old_v)**2, dim=1)# 从所有token的变化量分数中,找出分数最高的 topk_num 个token的索引top_indices = torch.topk(temp_diff, k=topk_num).indices# 遍历所有文档tokenfor i in range(num_tokens):# 如果一个token本来是候选(valid_ind[i] == 1),# 但它不在我们刚刚选出的“变化最大”的top_indices列表里,# 那么就把它从重算名单中剔除(设置为0)。if valid_ind[i] == 1 and i not in top_indices:valid_ind[i] = 0else:# 如果计算出的重算数量为0,则所有token都不重算valid_ind = [0] * num_tokens# 返回最终的决策列表。
# 注意:对于所有其他层 (layer_ind != 1),这个if块不会执行,
# 函数会直接返回从上一层接收到的valid_ind,不做任何修改。
return valid_ind
3. 举例说明:模拟执行过程
假设我们的文档部分共有10个token,并且现在正好进入了第二层 (layer_ind = 1
)。
-
参数初始化:
num_tokens
= 10recompute_ratio
= 0.25topk_num
=int(10 * 0.25)
= 2。算法的目标是选出最重要的2个token进行重算。- 此时从第0层传来的
valid_ind
是[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
,表示所有token都是候选。
-
计算信息变化量:
- 代码执行
temp_diff = torch.sum(...)
。 - 假设计算出的10个token的信息变化量分数是:
[0.2, 8.1, 1.5, 0.9, 9.5, 3.2, 0.1, 7.4, 2.8, 1.1]
。
- 代码执行
-
找出最重要的Token:
- 代码执行
torch.topk(temp_diff, k=2)
。 - 它会找到分数最高的两个值:
9.5
(在索引4的位置) 和8.1
(在索引1的位置)。 - 所以,
top_indices
列表就是[4, 1]
。
- 代码执行
-
更新决策列表:
- 代码开始
for
循环,遍历索引0
到9
。 - 索引0:
0 not in [4, 1]
->valid_ind[0]
变成0
。 - 索引1:
1 in [4, 1]
->valid_ind[1]
保持1
。 - 索引2:
2 not in [4, 1]
->valid_ind[2]
变成0
。 - 索引3:
3 not in [4, 1]
->valid_ind[3]
变成0
。 - 索引4:
4 in [4, 1]
->valid_ind[4]
保持1
。 - …以此类推…
- 代码开始
-
返回决策结果:
- 循环结束后,
valid_ind
变成了[0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
。 - 函数将这个列表返回给
qwen2.py
。
- 循环结束后,
-
后续层的影响:
- 当模型进入第三层 (
layer_ind = 2
) 时,choose_recompute
函数会再次被调用。 - 但这一次,
if layer_ind == 1:
的条件不满足,所以函数会直接return valid_ind
,也就是把[0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
这个列表原封不动地返回。 - 在所有更深的层(3, 4, 5…)都是如此。
- 当模型进入第三层 (
通过这个例子,您可以看到,这个示例算法实现了一个非常简单但有效的策略:在早期(第二层)识别出受上下文影响最大的25%的token,然后就锁定这些token,在后续所有层都只为它们进行重算。