当前位置: 首页 > news >正文

【速写】use_cache参数与decode再探讨

序言

纳什最近指出一个小细节,比如在Qwen系列模型中,两个special token:

  • eos_token(<|im_end|>): 151645(im_end 中的 im 指的是 instruct message)
  • pad_token(<|endoftext|>): 151643。

这是很有趣的事情,我们通常理解的 eos 反而是 pad token,而 eos 对应的是 <|im_end|>

我们在进行 sft 训练时,需要对 response 追加 eos,即要监督训练模型什么时候停止输出。

如前所说,eos 对应的 <|im_end|> 中的 im 指的是 instruct message(与之 pair 的是 <|im_start|>)。这俩 token(<|im_start|> <|im_end|> 封装进来的内容即为大家调 api 时用到的 role 和 content 的内容)。

即不管是 user 还是 assistant 它们的 eos 就是 <|im_end|>

这两个 token 也是从 base model (训练/推理)过度到 instruct model 的基础。

本文主要是探讨带kv-cache的解码方法的手写实现。实际上后来发现如果加上use_cache这个参数,可以省略大量的输入。这样是否意味着长上下文是个伪命题呢?其实本质上还是kv-cache太长了。那么之前所谓的sepllm(用点逗来替代长上下文)工作似乎就是隔靴搔痒了。


文章目录

  • 序言
  • 1 Beam Search 解码算法实现
  • 2 实现带KV Cache的Beam Search解码
  • 3 关于在带kv-cache的情况下的use_cache参数


1 Beam Search 解码算法实现

下面是一个使用PyTorch实现的beam search解码算法:

几个小细节:

  • 束搜索可以加入length_penalty,目前model.generate也是有这个参数的,这个惩罚项直接是用来除生成概率的
  • 通常这种需要计算概率相乘的情况,都是避免做乘法,而是使用log p相加
  • 具体实现中应当考虑eos标识符导致的early stop的候选序列,需要提前存储到外面
  • 然后就是关于使用log softmax得到log概率后,这其实是一个负的概率,序列越长,log prob会越小,- log prob 才是越大的,因此在做惩罚的时候,应该是吧 prob / len(seq) ** penality,即长序列的 log prob 会被除掉更多,这是合理的,因为短序列的 - log prob 天然地比 长序列地 - log prob 要更小,这样量纲才是正确的
import torch
import torch.nn.functional as F
from typing import List, Tupledef beam_search(model: torch.nn.Module,initial_input: torch.Tensor,beam_width: int,max_length: int,vocab_size: int,device: torch.device,length_penalty: float = 1.0,early_stopping: bool = True
) -> Tuple[List[List[int]], List[float]]:"""Beam search 解码算法实现参数:model: 用于预测下一个token的模型initial_input: 初始输入张量 (shape: [1, seq_len])beam_width: beam大小max_length: 生成序列的最大长度vocab_size: 词汇表大小device: 使用的设备 (cpu/cuda)length_penalty: 长度惩罚系数 (α), 用于调整对长序列的偏好early_stopping: 是否在所有beam序列达到EOS时提前停止返回:Tuple[List[List[int]], List[float]]: (生成的序列列表, 对应的分数列表)"""# 初始化beamsequences = [[initial_input.tolist()[0]]]  # 初始序列scores = [0.0]  # 初始分数 (log概率)# 存储完整的beam (已经生成EOS的序列)completed_sequences = []completed_scores = []for step in range(max_length):# 如果所有beam都已完成,提前停止if early_stopping and len(sequences) == 0:break# 准备当前步的输入candidates = []for i, seq in enumerate(sequences):# 跳过已经完成的序列if len(seq) > 0 and seq[-1] == 2:  # 假设2是EOS tokencompleted_sequences.append(seq)completed_scores.append(scores[i])continue# 将序列转换为张量input_tensor = torch.tensor([seq], dtype=torch.long).to(device)# 获取模型预测with torch.no_grad():outputs = model(input_tensor)next_token_logits = outputs[:, -1, :]  # 取最后一个token的logitsnext_token_probs = F.log_softmax(next_token_logits, dim=-1)# 获取top-k tokens和它们的log概率topk_probs, topk_tokens = torch.topk(next_token_probs, beam_width, dim=-1)topk_probs = topk_probs.squeeze(0)topk_tokens = topk_tokens.squeeze(0)# 创建候选序列for j in range(beam_width):new_seq = seq.copy()new_seq.append(topk_tokens[j].item())new_score = scores[i] + topk_probs[j].item()candidates.append((new_seq, new_score))# 如果没有候选序列,提前停止if not candidates:break# 选择top-k候选序列,并剪枝candidates.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty), reverse=True)sequences, scores = zip(*candidates[:beam_width])sequences = list(sequences)scores = list(scores)# 添加剩余的未完成序列到完成列表中completed_sequences.extend(sequences)completed_scores.extend(scores)# 对完成的序列按分数排序sorted_sequences = [seq for seq, _ in sorted(zip(completed_sequences, completed_scores),key=lambda x: x[1] / (len(x[0]) ** length_penalty),reverse=True)]sorted_scores = sorted(completed_scores,key=lambda score: score / (len(sorted_sequences[completed_scores.index(score)]) ** length_penalty),reverse=True)return sorted_sequences, sorted_scores
  1. 模型要求:

    • 模型应接受形状为 [batch_size, seq_len] 的输入
    • 模型应输出形状为 [batch_size, seq_len, vocab_size] 的logits
  2. 参数说明:

    • initial_input: 初始输入序列(如开始token)
    • beam_width: 控制搜索宽度,值越大结果可能越好但计算成本越高
    • length_penalty: 控制生成长度的偏好(α>1鼓励更长序列,α<1鼓励更短序列)
    • early_stopping: 设为True时,当所有beam都生成EOS token时停止
  3. 返回值:

    • 返回按分数排序的序列列表和对应的分数列表

如何调用?

model = ...  # 你的PyTorch模型
initial_input = torch.tensor([[1]])  # 假设1是开始token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)sequences, scores = beam_search(model=model,initial_input=initial_input,beam_width=5,max_length=50,vocab_size=10000,  # 你的词汇表大小device=device,length_penalty=1.2
)print("Top sequence:", sequences[0])
print("Score:", scores[0])

2 实现带KV Cache的Beam Search解码

这个是关键,之前写过一个k步的贪心解码,使用了递归写法,但不知道怎么在递归中加入kv cache,其实感觉怎么写都不太好。

其实这样看只要写循环就可以了。

实际上transformers的model.generate方法里,对greedy、beam、以及其他采样的方法都是有做 kv cache 缓存优化的:

# 首次调用(处理初始输入)
outputs = model(input_ids, use_cache=True)
logits = outputs.logits
past_key_values = outputs.past_key_values  # 缓存KV# 后续步骤(生成时)
next_token_input = torch.tensor([[new_token]])
outputs = model(next_token_input, past_key_values=past_key_values, use_cache=True)
new_past_key_values = outputs.past_key_values  # 更新缓存

要在Beam Search中为每个候选序列维护独立的KV Cache,我们需要对标准实现进行一些修改。KV Cache可以显著提高自回归模型的推理效率,因为它避免了重复计算先前token的key和value。

以下是带有KV Cache的Beam Search实现:

  • 因为我们已经有cache了,因此实际上每步迭代生成下一个token时,只需要把当前序列的最后一个token输入即可,只是很棒的操作
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Anydef beam_search_with_kv_cache(model: torch.nn.Module,initial_input: torch.Tensor,beam_width: int,max_length: int,vocab_size: int,device: torch.device,length_penalty: float = 1.0,early_stopping: bool = True,use_kv_cache: bool = True
) -> Tuple[List[List[int]], List[float]]:"""带KV Cache的Beam Search解码算法参数:model: 用于预测下一个token的模型initial_input: 初始输入张量 (shape: [1, seq_len])beam_width: beam大小max_length: 生成序列的最大长度vocab_size: 词汇表大小device: 使用的设备 (cpu/cuda)length_penalty: 长度惩罚系数early_stopping: 是否在所有beam序列达到EOS时提前停止use_kv_cache: 是否使用KV Cache加速返回:Tuple[List[List[int]], List[float]]: (生成的序列列表, 对应的分数列表)"""# 初始化beamsequences = [[initial_input.tolist()[0]]]scores = [0.0]# 存储KV Cache (每个候选序列一个cache)kv_caches = [None]  # 初始cache为None# 存储完整的beamcompleted_sequences = []completed_scores = []for step in range(max_length):if early_stopping and len(sequences) == 0:breakcandidates = []new_kv_caches = []for i, (seq, score, kv_cache) in enumerate(zip(sequences, scores, kv_caches)):# 跳过已经完成的序列if len(seq) > 0 and seq[-1] == 2:  # 假设2是EOS tokencompleted_sequences.append(seq)completed_scores.append(score)continue# 准备输入 (只使用最后一个token,因为前面的已经cache了)input_tensor = torch.tensor([[seq[-1]]], dtype=torch.long).to(device)# 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input, use_cache=True)next_token_logits = outputs.logits[:, -1, :]new_kv_cache = outputs.past_key_valueselse:# 后续调用,使用KV Cacheoutputs = model(input_tensor, past_key_values=kv_cache, use_cache=True)next_token_logits = outputs.logits[:, -1, :]new_kv_cache = outputs.past_key_valueselse:# 不使用KV Cache的情况full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input, use_cache=False)next_token_logits = outputs.logits[:, -1, :]new_kv_cache = Nonenext_token_probs = F.log_softmax(next_token_logits, dim=-1)# 获取top-k tokenstopk_probs, topk_tokens = torch.topk(next_token_probs, beam_width, dim=-1)topk_probs = topk_probs.squeeze(0)topk_tokens = topk_tokens.squeeze(0)# 创建候选序列for j in range(beam_width):new_seq = seq.copy()new_seq.append(topk_tokens[j].item())new_score = score + topk_probs[j].item()candidates.append((new_seq, new_score, new_kv_cache))# 如果没有候选序列,提前停止if not candidates:break# 选择top-k候选序列candidates.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty), reverse=True)top_candidates = candidates[:beam_width]# 解包候选序列sequences = [cand[0] for cand in top_candidates]scores = [cand[1] for cand in top_candidates]kv_caches = [cand[2] for cand in top_candidates]# 添加剩余的未完成序列completed_sequences.extend(sequences)completed_scores.extend(scores)# 对完成的序列按分数排序sorted_pairs = sorted(zip(completed_sequences, completed_scores),key=lambda x: x[1] / (len(x[0]) ** length_penalty),reverse=True)sorted_sequences = [pair[0] for pair in sorted_pairs]sorted_scores = [pair[1] for pair in sorted_pairs]return sorted_sequences, sorted_scores

关键修改说明

  1. KV Cache维护:

    • 每个候选序列现在维护自己的KV Cache
    • 初始时KV Cache为None
    • 第一次处理序列时生成完整的KV Cache
    • 后续步骤只处理最后一个token并更新KV Cache
  2. 模型接口要求:

    • 模型需要支持past_key_values参数来接收先前的KV Cache
    • 模型需要返回past_key_values作为输出的一部分
    • 典型实现方式(如HuggingFace的transformers):
      outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
      next_token_logits = outputs.logits
      past_key_values = outputs.past_key_values
      
  3. 性能优化:

    • 使用KV Cache后,每次前向传播只处理最后一个token(这个很有趣,但是要设置use_cache=True
    • 避免了重复计算先前token的key和value
    • 对于长序列可以显著提高速度

一个简单的调用示例:

# 假设我们有一个支持KV Cache的模型
model = ...  # 例如HuggingFace的GPT2模型
initial_input = torch.tensor([[model.config.bos_token_id]])  # 开始token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 运行带KV Cache的beam search
sequences, scores = beam_search_with_kv_cache(model=model,initial_input=initial_input,beam_width=5,max_length=50,vocab_size=model.config.vocab_size,device=device,length_penalty=1.2,use_kv_cache=True  # 启用KV Cache
)print("Top sequence:", sequences[0])
print("Score:", scores[0])

补注:

在这个部分:

            # 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input, use_cache=True)

上,输出的full_input 的size是[1, 1, seqlen],理论上应该是[1, seqlen]才对,因此要么是

            # 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor(seq, dtype=torch.long).to(device)outputs = model(full_input, use_cache=True)

要么是:

            # 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input.squeeze(0), use_cache=True)

这样测试跑通应该是没有问题的


3 关于在带kv-cache的情况下的use_cache参数

比如之前手写的一个贪心解码算法:

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport torch
import logging
from copy import deepcopy
from functools import wraps
from torch.nn import functional as Ffrom transformers import AutoTokenizer, AutoModelForCausalLM# Standard greedy decode
# @param model: Huggingface model object
# @param tokenizer: Huggingface tokenizer Object
# @param prompt: Str
# @param max_length: Int, the number of tokens to be generated
# @param device: Str, e.g. "cuda" or "cpu"
# @param kv_cache: Boolean, whether to use KV-cache to accelerate, if True then large memory will be consumed
# @return generated_text: Str
# @return generated_token_prob: List[Tuple(Int, Str, Float)], `len(generated_id_prob)` is `max_length`, indicating the generated probability of each token
# @return generated_logits: Tuple[FloatTensor(1, n_vocab)], `len(generated_logits)` is `max_length`, indicating the logits when each token is generated
def greedy_decode(model,tokenizer,prompt, max_length,device = "cuda",kv_cache = True,):inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)	# Str => Long(1, n_tokens)past_key_values = Nonegenerated_token_probs = list()generated_logits = list()model.gradient_checkpointing_enable()for i in range(max_length):logging.info(f"Round {i}: {past_key_values.key_cache[0].size() if past_key_values is not None else None}")outputs = model(inputs, past_key_values=past_key_values)logits = outputs.logits	# Float(1, n_tokens + i + 1, n_vocab), where `n_vocab` is 151936 in DeepSeek-R1-Distill-Qwenif kv_cache:past_key_values = outputs.past_key_values	# Dictlike[key_cache: Float(1, 2, X, hidden_size), value_cache: Float(1, 2, X, hidden_size)], where X = (i + 1) * (n_tokens + i / 2)next_token_probs = F.softmax(logits[:, -1, :], dim=-1)	# Float(1, n_tokens + i + 1, n_vocab) => Float(1, n_vocab)next_token_id = torch.argmax(next_token_probs, dim=-1)	# Float(1, n_vocab) => Long(1, )next_token_prob = next_token_probs[0, next_token_id].item()	# Float(1, n_vocab) => Float()next_token = tokenizer.decode(next_token_id[0].item(), skip_special_tokens=False)	# Long(1, ) => Strinputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1)	# Long(1, n_tokens + i) => Long(1, n_tokens + i + 1)generated_token_probs.append((next_token_id.item(), next_token, next_token_prob))generated_logits.append(logits[:, -1, :])generated_text = tokenizer.decode(token_ids = inputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True,)	# Long(1, n_tokens + max_length) => Strreturn generated_text, generated_token_probs, tuple(generated_logits)

实际上除了第一次输入外,接下来都可以用最后一个token作为输入,而不需要把之前整个一长串的input都输入到model中去:

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport torch
import logging
from copy import deepcopy
from functools import wraps
from torch.nn import functional as Ffrom transformers import AutoTokenizer, AutoModelForCausalLM# Standard greedy decode
# @param model: Huggingface model object
# @param tokenizer: Huggingface tokenizer Object
# @param prompt: Str
# @param max_length: Int, the number of tokens to be generated
# @param device: Str, e.g. "cuda" or "cpu"
# @param kv_cache: Boolean, whether to use KV-cache to accelerate, if True then large memory will be consumed
# @return generated_text: Str
# @return generated_token_prob: List[Tuple(Int, Str, Float)], `len(generated_id_prob)` is `max_length`, indicating the generated probability of each token
# @return generated_logits: Tuple[FloatTensor(1, n_vocab)], `len(generated_logits)` is `max_length`, indicating the logits when each token is generated
def greedy_decode(model,tokenizer,prompt, max_length,device = "cuda",kv_cache = True,):inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)	# Str => Long(1, n_tokens)past_key_values = Nonegenerated_token_probs = list()generated_logits = list()model.gradient_checkpointing_enable()for i in range(max_length):logging.info(f"Round {i}: {past_key_values.key_cache[0].size() if past_key_values is not None else None}")if kv_cache:if i == 0:outputs = model(inputs, past_key_values=past_key_values)else:outputs = model(inputs[:, -1].unsqueeze(0), past_key_values=past_key_values, use_cache=True)else:outputs = model(inputs, past_key_values=None)logits = outputs.logits	# Float(1, n_tokens + i + 1, n_vocab), where `n_vocab` is 151936 in DeepSeek-R1-Distill-Qwenif kv_cache:past_key_values = outputs.past_key_values	# Dictlike[key_cache: Float(1, 2, X, hidden_size), value_cache: Float(1, 2, X, hidden_size)], where X = (i + 1) * (n_tokens + i / 2)next_token_probs = F.softmax(logits[:, -1, :], dim=-1)	# Float(1, n_tokens + i + 1, n_vocab) => Float(1, n_vocab)next_token_id = torch.argmax(next_token_probs, dim=-1)	# Float(1, n_vocab) => Long(1, )next_token_prob = next_token_probs[0, next_token_id].item()	# Float(1, n_vocab) => Float()next_token = tokenizer.decode(next_token_id[0].item(), skip_special_tokens=False)	# Long(1, ) => Strinputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1)	# Long(1, n_tokens + i) => Long(1, n_tokens + i + 1)generated_token_probs.append((next_token_id.item(), next_token, next_token_prob))generated_logits.append(logits[:, -1, :])generated_text = tokenizer.decode(token_ids = inputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True,)	# Long(1, n_tokens + max_length) => Strreturn generated_text, generated_token_probs, tuple(generated_logits)

这个确实是很有帮助的,能加速推理很多。这个原理其实很简单,因为只需要KVcache与最后一个token就可以计算得到下一层的注意力权重(其实就是下一轮生成的KVcache),然后倒是发现deepseek在生成图像链接时出错了,难得逮到DeepSeek犯错的时候(生成图片链接失败):

在这里插入图片描述

use_cache=True 时,Transformer 模型的前向传播会启用 KV Cache 机制,这是解码阶段(如文本生成)的核心优化手段。以下是其具体执行逻辑和缓存内容的详细说明:

模型会缓存每一层(Layer)的 Key 矩阵(K)Value 矩阵(V),这些矩阵来自历史 token 的自注意力计算。具体来说:

  • K Cache: 形状为 [batch_size, num_heads, seq_len, head_dim]
  • V Cache: 形状为 [batch_size, num_heads, seq_len, head_dim]
  • Query(Q) 是当前 token 的向量,每次生成时需重新计算。
  • Key/Value 是历史 token 的向量,生成新 token 时可直接复用,无需重复计算。

看起来如果设置了use_cache = True的话,其实根本就不需要再手动更新kv_cache了,但这个事情还不太好直接验证。只能从运行时间上分辨。

  1. 首次调用(处理完整输入序列)
outputs = model(input_ids, use_cache=True)
  • 计算步骤:
    1. 对输入的所有 token 计算完整的自注意力(包括 Q、K、V)。
    2. 将每一层的 K 和 V 存入 past_key_values(形状为 [num_layers, 2, batch_size, num_heads, seq_len, head_dim])。
    3. 返回最后一个 token 的 logits 和缓存的 past_key_values
  1. 后续调用(生成新 token)
outputs = model(new_token, past_key_values=past_key_values, use_cache=True)
  • 计算步骤:
    1. 仅计算新 token 的 Q(因为 K/V 已缓存)。
    2. 将新 token 的 Q 与缓存的 K/V 计算注意力分数:
      Attention(Q_new, K_cache, V_cache) = softmax(Q_new @ K_cache^T / √d) @ V_cache
      
    3. 将新 token 自身的 K/V 追加到缓存中,更新 past_key_values
    4. 返回新 token 的 logits 和更新后的 past_key_values

KV Cache 的代码级实现(以 HuggingFace 为例)

缓存的数据结构

past_key_values = [(K_layer1, V_layer1),  # 第1层的K/V(K_layer2, V_layer2),  # 第2层的K/V...                    # 所有层的K/V
]

关键代码逻辑

# 在模型的自注意力层中(简化版)
if use_cache:# 合并历史K/V与新K/Vkey_states = torch.cat([past_key_values[0], current_key], dim=2])  # 沿seq_len维度拼接value_states = torch.cat([past_key_values[1], current_value], dim=2])# 更新缓存present_key_values = (key_states, value_states)
else:present_key_values = None

KV Cache 的显存占用分析

假设以下参数:

  • 模型层数 L(如 LLaMA-7B 有 32 层)
  • 注意力头数 H(如 32)
  • 头维度 D(如 128)
  • 序列长度 S
  • 批大小 B
  • 数据类型 dtype(如 float16 占 2字节)

缓存总大小

显存 ≈ L × 2 × B × H × S × D × dtype_size

例如:LLaMA-7B 生成 1024 token 时,单样本的缓存约占用 32×2×1×32×1024×128×2 = 512MB


KV Cache 的优化效果

操作计算复杂度(无缓存)计算复杂度(有缓存)
生成第 N 个 tokenO(N²)O(N)
显存占用O(1)O(N)
  • 速度提升:生成 1000 token 时,理论加速约 1000 倍(从 1M 次计算降到 1K 次)。
  • 代价:显存随序列长度线性增长。

相关文章:

  • 【嵌入式系统设计师(软考中级)】第三章:嵌入式系统软件基础知识——①软件及操作系统基础
  • 电脑端音乐播放器推荐:提升你的听歌体验!
  • 免费多线程下载工具
  • 数字人教学技术与产品方案的全面解析
  • 【论信息系统项目的质量管理】
  • MySQL创建了一个索引表,如何来验证这个索引表是否使用了呢?
  • 在Windows 境下,将Redis和Nginx注册为服务。
  • 自适应主从复制模拟器的构建与研究
  • 使用ACE-Step在本地生成AI音乐
  • 双向链表专题
  • DAY05:深入解析生命周期与钩子函数
  • MYSQL事务原理分析(三)
  • nginx配置sse流传输问题:直到所有内容返回后才往下传输
  • java反序列化commons-collections链6
  • LVGL(lv_switch开关)
  • 输出重定向
  • 位运算题目:黑板异或游戏
  • 牛客周赛 Round 92
  • ComfyUI的K采样器参数详解:实战演示
  • Python 实现失败重试功能的几种方法
  • 福建宁德市长张永宁拟任设区市党委正职,曾获评全国优秀县委书记
  • 这些网红果蔬正在收割你的钱包,营养师:吃了个寂寞
  • 乘联分会:上半年车市价格竞争温和,下半年价格战或再开启
  • 技术派|巴基斯坦导弹:让印度保持克制的“定海神针”?
  • 著名蒙古族音乐学者马•斯尔古愣逝世,享年86岁
  • 巴基斯坦全面恢复领空开放