【大模型推理】ScheduleBatch 学习
很好问题!让我详细解释EXTEND和预填充(prefill)的概念,以及它们与解码(decode)的区别。
1. EXTEND 的含义
定义
EXTEND模式指的是扩展序列的操作,即处理新的输入token并将其KV缓存写入缓存池。
具体场景
# 场景1: 新请求的首次处理
输入: "Hello, how are you?"
操作: 将整个输入序列的KV缓存写入KV缓存池# 场景2: 长文本的分块处理
输入: "This is a very long document that needs to be processed in chunks..."
操作: 分块处理,每次处理一部分token# 场景3: 流式输入追加
已处理: "The weather is"
新输入: " nice today"
操作: 只处理新追加的token
在代码中的体现
def prepare_for_extend(self):self.forward_mode = ForwardMode.EXTEND# 计算需要扩展的tokeninput_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]extend_num_tokens = sum(len(ids) for ids in input_ids)# 这些token都是新的,需要计算并缓存它们的KV
2. Prefill (预填充) vs EXTEND
传统意义上的Prefill
Prefill通常指首次填充,即处理整个输入提示(prompt):
# 传统Prefill
输入: "Translate the following English to French: 'Hello world'"
输出: 无(只填充KV缓存)
SGLang中的EXTEND概念
在SGLang中,EXTEND是更广义的概念:
# 包含传统Prefill,但更广泛
class ForwardMode(Enum):EXTEND # 扩展序列(包含预填充和续写)DECODE # 自回归解码生成MIXED # 混合模式IDLE # 空闲
关键区别
| 方面 | 传统Prefill | SGLang EXTEND |
|---|---|---|
| 范围 | 只处理初始prompt | 处理任何新token |
| 输出 | 通常不生成输出 | 可能生成输出token |
| 缓存 | 填充初始KV缓存 | 可能利用已有前缀缓存 |
3. EXTEND 的具体工作流程
3.1 前缀缓存利用
def init_next_round_input(self, tree_cache=None):self.fill_ids = self.origin_input_ids + self.output_ids# 查询前缀缓存,找到可重用的部分if tree_cache is not None:self.prefix_indices, self.last_node = tree_cache.match_prefix(rid=self.rid, key=self.adjust_max_prefix_ids())# 只处理新的部分self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
示例:
已有缓存: [1, 2, 3, 4] # 前缀索引
完整序列: [1, 2, 3, 4, 5, 6, 7] # fill_ids
EXTEND部分: [5, 6, 7] # 只需要处理这些新token
3.2 内存分配策略
def prepare_for_extend(self):# 分配请求槽位req_pool_indices = self.alloc_req_slots(len(self.reqs))# 分配KV缓存位置if self.token_to_kv_pool_allocator.page_size == 1:out_cache_loc = self.alloc_token_slots(extend_num_tokens)else:# 分页分配out_cache_loc = self.alloc_paged_token_slots_extend(...)# 建立映射关系write_req_to_token_pool_triton(...)
4. EXTEND vs DECODE 的对比
4.1 处理模式对比
# EXTEND 模式 - 并行处理多个token
输入: [token1, token2, token3, ...] # 可变长度序列
处理: 并行计算所有token的注意力
输出: 可能生成下一个token的概率分布# DECODE 模式 - 单步自回归
输入: [last_token] # 单个token(上一步的输出)
处理: 基于整个历史生成下一个token
输出: 下一个token的ID
4.2 计算特征对比
| 特征 | EXTEND模式 | DECODE模式 |
|---|---|---|
| 计算复杂度 | O(n²) - 全注意力 | O(n) - 增量注意力 |
| 并行度 | 高 - 多个token并行 | 低 - 单个token |
| 内存访问 | 不规则 - 可变长度 | 规则 - 固定批次 |
| KV缓存 | 写入新的缓存位置 | 读取现有缓存 |
4.3 实际代码对比
EXTEND 准备
def prepare_for_extend(self):# 处理可变长度序列input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]extend_num_tokens = sum(len(ids) for ids in input_ids) # 可变总数# 扁平化处理input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64)# 形状: [extend_num_tokens]
DECODE 准备
def prepare_for_decode(self):# 处理固定批次(每个请求1个token)bs = len(self.reqs)# 输入是上一步的输出self.input_ids = self.output_ids # 形状: [bs]# 每个请求分配1个新KV位置self.out_cache_loc = self.alloc_token_slots(bs) # 形状: [bs]
5. EXTEND 的使用场景
5.1 新请求初始化
# 用户发送新请求
request = Req(rid="req1",origin_input_text="What is the capital of France?",origin_input_ids=[123, 456, 789, ...],sampling_params=sampling_params
)# 创建批次并准备EXTEND
batch = ScheduleBatch.init_new([request], ...)
batch.prepare_for_extend() # 首次处理整个prompt
5.2 长文本分块处理
# 长文档处理
def process_long_document(document_tokens, chunk_size=512):for i in range(0, len(document_tokens), chunk_size):chunk = document_tokens[i:i+chunk_size]# 准备EXTEND处理这个分块req.fill_ids = existing_tokens + chunkreq.init_next_round_input(tree_cache)batch.prepare_for_extend() # 处理这个分块# 执行模型前向传播...
5.3 流式输入追加
# 流式对话场景
def handle_streaming_input(req, new_input_tokens):# 追加新输入到现有序列req.output_ids.extend(previous_output)req.fill_ids = req.origin_input_ids + req.output_ids + new_input_tokens# 只EXTEND新输入的部分req.init_next_round_input(tree_cache)batch.prepare_for_extend()
6. 性能优化考虑
6.1 EXTEND的挑战
# 问题: 可变长度导致计算效率低
extend_num_tokens = sum(len(ids) for ids in input_ids) # 可能很大且不规则# 解决方案: 分块和填充
if enable_chunked_prefill:# 将长序列分成固定大小的块process_in_chunks(req, chunk_size=256)
6.2 前缀缓存的价值
def init_next_round_input(self, tree_cache):# 利用前缀缓存避免重复计算self.prefix_indices, self.last_node = tree_cache.match_prefix(...)self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)# 只有 extend_input_len 个token需要实际计算
性能收益:
- 减少计算量:只计算新token
- 减少内存分配:重用已有KV缓存位置
- 提高吞吐量:更快的处理速度
7. 实际执行流程示例
完整请求生命周期
# 1. 新请求 - EXTEND处理整个prompt
batch.prepare_for_extend()
# 输入: [What, is, the, capital, of, France, ?]
# 输出: 第一个生成token的概率分布# 2. 解码循环 - DECODE模式
for step in range(max_steps):batch.prepare_for_decode()# 输入: [last_generated_token]# 输出: 下一个token的概率分布# 采样得到新tokennew_token = sample_from_logits(logits)req.output_ids.append(new_token)# 检查是否结束req.check_finished()if req.finished():break# 3. 如果有新输入 - 再次EXTEND
if has_new_user_input:req.init_next_round_input(tree_cache)batch.prepare_for_extend() # 处理新输入
8. 与Profile的关系
Profile模式的作用
PROFILE模式通常用于性能分析,而不是实际推理:
# Profile模式 - 测量性能,不产生实际输出
if profile_mode:# 测量EXTEND操作的性能start_time = time.time()batch.prepare_for_extend()extend_time = time.time() - start_time# 测量DECODE操作的性能 start_time = time.time()batch.prepare_for_decode()decode_time = time.time() - start_timelogger.info(f"EXTEND: {extend_time:.3f}s, DECODE: {decode_time:.3f}s")
实际使用中的Profile
在真实系统中,Profile可能用于:
- 容量规划:测量系统能处理的最大批次大小
- 优化调试:识别性能瓶颈
- 资源分配:根据EXTEND/DECODE比例调整资源
总结
EXTEND在SGLang中是一个核心概念,它:
- 广义的预填充:不仅处理初始prompt,还处理任何新token
- 前缀缓存感知:智能利用共享前缀避免重复计算
- 性能关键路径:通常比DECODE更计算密集,需要精心优化
- 流式处理基础:支持长文本分块和流式输入
与传统的Prefill相比,EXTEND更灵活,支持更复杂的应用场景,是现代LLM推理服务的重要组成部分。
我们继续使用EXTEND模式来处理新的输入,这类似于预填充,但有一些关键区别。
在传统的预填充(prefill)中,我们通常是指处理整个提示(prompt)以生成第一个输出token。而在SGLang中,EXTEND模式用于处理任何新的输入token,这些新输入token可能是:
整个提示(第一次处理)
长提示中的后续分块(当使用分块处理时)
在流式对话中,用户新输入的消息(追加到现有对话之后)
所以,EXTEND模式比传统的预填充更通用。它允许我们利用前缀缓存(如果存在)来避免重复计算,只计算新token的KV缓存。
具体来说,当我们有新的用户输入时,我们会将新输入追加到现有序列中,然后通过EXTEND模式处理这些新token。这与第一次预填充的不同之处在于:
第一次预填充(整个提示)时,前缀缓存是空的,所以我们需要计算整个提示的KV缓存。
当处理后续分块或新输入时,我们可能已经缓存了部分序列(前缀),所以只需要计算新追加的token。
因此,EXTEND模式是一种增量式的预填充,它利用已有的缓存来减少计算量。
在代码中,我们通过以下步骤准备EXTEND:
将新输入追加到请求的fill_ids(即完整序列:原始输入+已生成输出+新输入)
调用init_next_round_input来更新前缀缓存和计算需要扩展的token数量(extend_input_len)
调用prepare_for_extend来分配KV缓存位置并构建模型输入。
这样,模型只会处理新添加的token,而不会重新处理整个序列。
举个例子:
假设我们有一个请求,已经处理了初始提示:“Hello, how are you?”,并生成了一些输出:“I’m fine.”。现在用户又输入了新的消息:“What about you?”。
那么,完整的序列是:“Hello, how are you? I’m fine. What about you?”。
但是,由于我们已经缓存了"Hello, how are you? I’m fine."的KV,所以我们只需要处理新追加的"What about you?"部分。
这就是EXTEND模式与传统预填充的区别:它允许我们利用缓存来增量处理,而不是每次都要处理整个序列。
因此,在代码中,当我们有新的用户输入时,我们会再次进入EXTEND模式,但只处理新输入的部分,从而大大提高效率。
非常好问题!让我详细解释"再次EXTEND"的含义,以及它与传统Prefill的根本区别。
非常好问题!让我详细解释"再次EXTEND"的含义,以及它与传统Prefill的根本区别。
1. 场景示例:多轮对话
用户与AI的对话流程
# 第一轮对话
用户: "What is the capital of France?" # 初始输入
AI: "The capital of France is Paris."# 第二轮对话
用户: "And what about Germany?" # 新输入
AI: "The capital of Germany is Berlin."
在系统中的处理流程
# 第一轮:处理初始问题
req1 = Req(rid="conv1", input_text="What is the capital of France?")
batch.prepare_for_extend() # 传统Prefill:处理整个问题
# → 生成回答: "The capital of France is Paris."# 第二轮:处理追加问题
req1.add_input("And what about Germany?") # 追加新输入
req1.init_next_round_input(tree_cache)
batch.prepare_for_extend() # 再次EXTEND:只处理新问题
# → 生成回答: "The capital of Germany is Berlin."
2. 再次EXTEND vs 传统Prefill的关键区别
2.1 计算范围不同
传统Prefill(首次EXTEND)
# 处理整个序列
输入序列: [What, is, the, capital, of, France, ?]
需要计算的token: 7个
KV缓存写入: 7个新位置
前缀缓存: [] (空的,没有可重用的)
再次EXTEND(增量EXTEND)
# 只处理新增部分
完整序列: [What, is, the, capital, of, France, ?, The, capital, of, France, is, Paris, ., And, what, about, Germany, ?]↑ 已缓存部分 ↑ ↑ 已生成部分 ↑ ↑ 新增输入部分 ↑
前缀缓存: [What, is, the, capital, of, France, ?, The, capital, of, France, is, Paris, .]
需要计算的token: [And, what, about, Germany, ?] # 只有5个
KV缓存写入: 5个新位置
2.2 前缀缓存利用
代码层面的差异
def init_next_round_input(self, tree_cache):# 构建完整序列self.fill_ids = self.origin_input_ids + self.output_ids + new_input_ids# 关键:查询前缀缓存,找到可重用的部分if tree_cache is not None:self.prefix_indices, self.last_node = tree_cache.match_prefix(key=self.adjust_max_prefix_ids() # 查询已有缓存)# 只计算新增部分self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)# 首次Prefill: extend_input_len = 整个序列长度# 再次EXTEND: extend_input_len = 新增部分长度
2.3 内存分配差异
首次Prefill的内存分配
def prepare_for_extend(self):# 整个序列都是新的input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs] # 整个序列extend_num_tokens = sum(len(ids) for ids in input_ids) # 可能很大# 需要分配大量KV缓存位置out_cache_loc = self.alloc_token_slots(extend_num_tokens)
再次EXTEND的内存分配
def prepare_for_extend(self):# 只有新增部分是新的input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs] # 只有新增部分extend_num_tokens = sum(len(ids) for ids in input_ids) # 通常较小# 只需要分配少量KV缓存位置out_cache_loc = self.alloc_token_slots(extend_num_tokens) # 更少的内存需求
3. 具体技术实现细节
3.1 前缀缓存的工作原理
# 假设的对话历史
history = {"session1": [[1, 2, 3, 4], # "What is the capital"[1, 2, 3, 4, 5], # "What is the capital of" [1, 2, 3, 4, 5, 6] # "What is the capital of France"]
}def match_prefix(self, key):"""在缓存树中查找最长的匹配前缀"""# key = [1, 2, 3, 4, 5, 6, 10, 11] # 完整序列# 返回: [1, 2, 3, 4, 5, 6] 的缓存位置# 剩余: [10, 11] 需要处理return prefix_indices, last_node
3.2 序列状态的变化
首次Prefill后
req = Req(...)
# 初始状态
req.origin_input_ids = [1, 2, 3, 4, 5, 6] # "What is the capital of France?"
req.output_ids = [7, 8, 9, 10, 11, 12] # "The capital of France is Paris."
req.fill_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]# 前缀缓存状态
req.prefix_indices = [100, 101, 102, 103, 104, 105] # 对应KV缓存位置
收到新输入后
# 用户新输入
new_input = [13, 14, 15, 16, 17] # "And what about Germany?"# 更新序列
req.fill_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
# ↑ 已缓存 ↑ ↑ 新输入 ↑# 重新查询前缀缓存
req.init_next_round_input(tree_cache)
# prefix_indices = [100, 101, 102, 103, 104, 105, ...] # 可能扩展到输出部分
# extend_input_len = 5 # 只有新输入需要处理
4. 性能优势分析
4.1 计算复杂度对比
传统方法(无缓存)
# 每次都要处理整个历史
总计算量 = O(L²) # L是累计序列长度
第二轮: O(17²) = 289
第三轮: O(22²) = 484 # 越来越慢
再次EXTEND(有缓存)
# 只处理新增部分
总计算量 = O(L_existing²) + O(L_new²) # 但L_existing的KV已缓存
第二轮: O(5²) = 25 # 只处理5个新token
第三轮: O(5²) = 25 # 保持稳定
4.2 内存使用对比
KV缓存重用
# 首次Prefill
KV缓存分配: 位置[100-111] # 12个位置# 再次EXTEND
KV缓存分配: 位置[112-116] # 只新增5个位置
# 位置[100-111] 被重用,不需要重新计算
5. 实际应用场景
5.1 多轮对话系统
class ChatSession:def __init__(self, session_id):self.session_id = session_idself.req = Nonedef add_message(self, user_input):if self.req is None:# 首次Prefillself.req = Req(rid=self.session_id, input_text=user_input)batch.prepare_for_extend()else:# 再次EXTENDself.req.add_input(user_input)self.req.init_next_round_input(tree_cache)batch.prepare_for_extend()# 执行推理并返回结果return self.execute_inference()
5.2 长文档处理
def process_long_document(document, chunk_size=512):req = Req(rid="doc1", input_text=document[:chunk_size])# 首次Prefill第一个分块batch.prepare_for_extend()# 处理剩余分块for i in range(chunk_size, len(document), chunk_size):chunk = document[i:i+chunk_size]req.add_input(chunk)req.init_next_round_input(tree_cache)batch.prepare_for_extend() # 再次EXTEND处理新分块
5.3 流式输入应用
# 实时翻译或语音识别
def stream_processor():req = Req(rid="stream1")while True:new_tokens = get_new_tokens_from_stream()if new_tokens:req.add_input(new_tokens)req.init_next_round_input(tree_cache)if req.extend_input_len > 0: # 有新内容需要处理batch.prepare_for_extend()yield generate_output()
6. 与传统方法的根本区别
6.1 状态持续性
传统Prefill(无状态)
# 每次请求都是独立的
def handle_request(request_text):req = Req(input_text=request_text) # 新建请求batch.prepare_for_extend() # 完整处理# 完成后丢弃所有状态
再次EXTEND(有状态)
# 保持会话状态
def handle_conversation(session_id, new_input):req = get_session(session_id) # 获取已有请求req.add_input(new_input) # 追加新输入req.init_next_round_input(tree_cache) # 智能确定需要处理的部分batch.prepare_for_extend() # 增量处理# 保持所有状态供下次使用
6.2 缓存策略
传统:每次重新计算
用户: "A B C" → 计算: A, B, C
AI: "X Y Z"
用户: "D E" → 计算: A, B, C, X, Y, Z, D, E # 重复计算!
再次EXTEND:智能重用
用户: "A B C" → 计算: A, B, C → 缓存: [A, B, C]
AI: "X Y Z" → 计算: X, Y, Z → 缓存: [A, B, C, X, Y, Z]
用户: "D E" → 重用: [A, B, C, X, Y, Z] → 只计算: D, E
7. 代码层面的具体差异
7.1 init_next_round_input 的关键作用
def init_next_round_input(self, tree_cache, enable_hierarchical_cache=False):# 更新完整序列self.fill_ids = self.origin_input_ids + self.output_ids# 关键区别点:if tree_cache is not None:# 再次EXTEND:这里能找到大量匹配的前缀# 首次Prefill:这里可能只能找到很少或没有匹配self.prefix_indices, self.last_node = tree_cache.match_prefix(key=self.adjust_max_prefix_ids())# 计算需要实际处理的部分self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)# 首次Prefill: extend_input_len ≈ 整个序列长度# 再次EXTEND: extend_input_len ≪ 整个序列长度
7.2 prepare_for_extend 的实际影响
def prepare_for_extend(self):# 输入数据收集 - 关键差异点input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]# 首次Prefill: input_ids 包含整个prompt# 再次EXTEND: input_ids 只包含新增部分extend_num_tokens = sum(len(ids) for ids in input_ids)# 这个数字在再次EXTEND时显著变小# 内存分配相应减少out_cache_loc = self.alloc_token_slots(extend_num_tokens)
总结
再次EXTEND与传统Prefill的根本区别:
| 方面 | 传统Prefill | 再次EXTEND |
|---|---|---|
| 计算范围 | 整个输入序列 | 只新增部分 |
| 缓存利用 | 无或很少 | 大量重用已有缓存 |
| 内存分配 | 分配整个序列 | 只分配新增部分 |
| 性能特征 | 计算密集型 | 相对轻量 |
| 应用场景 | 新请求初始化 | 会话延续、流式输入 |
核心价值:再次EXTEND使得LLM能够以增量方式处理输入,而不是每次都重新处理整个历史,这对于多轮对话、长文档处理和流式应用至关重要,能显著提升性能和用户体验。
在 SGLang(以及大多数现代 LLM 推理引擎,如 vLLM)中,extend 模式(也常被称为 prefill 模式)是指 对一个新请求(或新 chunk)的输入序列进行首次 KV Cache 填充和 logits 计算的过程。
一、核心定义
ForwardMode.EXTEND:表示当前 batch 中的所有请求都在执行 prefill(首次处理)。- 它对应的是 从输入 token 到生成第一个输出 token 之前的完整前向计算。
✅ 简单说:
extend = prefill,只是 SGLang 用词为 “extend”。
二、为什么叫 “extend”?
因为:
- 每个请求可能已有 共享前缀(来自
RadixCache/ChunkCache) - 当前操作是 “扩展”这个前缀,把 尚未计算的 token 部分(即
extend_input_len)填入 KV Cache
例如:
- 请求已有前缀 100 个 token(已缓存)
- 输入总长 120 个 token
- 则
extend_input_len = 20,需要 extend 这 20 个 token
三、与 decode 模式的对比
| 特性 | extend(prefill) | decode(自回归生成) |
|---|---|---|
| 输入长度 | 可变(1 ~ 几千) | 固定为 1(每个请求 1 个 token) |
| KV Cache 操作 | 写入多个新 token 的 K/V | 写入 1 个新 token 的 K/V |
| 计算量 | 大(O(n²) Attention) | 小(O(n) Attention) |
| batch shape | 不规则(ragged) | 规则(可启用 CUDA Graph) |
| 是否支持 CUDA Graph | ❌ 否 | ✅ 是 |
| 典型场景 | 请求首次进入、chunked prefill | 生成第 2、3、4… 个 token |
四、在 SGLang 中的关键字段(ForwardBatch)
当 forward_mode == ForwardMode.EXTEND 时,以下字段有效:
extend_num_tokens: 本次 batch 总共要 extend 的 token 数extend_seq_lens: 每个请求要 extend 的 token 数(如[20, 30, 10])extend_prefix_lens: 每个请求已缓存的前缀长度(如[100, 50, 0])positions: 每个 extend token 的绝对位置(如[100,101,...,119, 50,51,...,79, 0,1,...,9])extend_start_loc: 每个请求在 flatteninput_ids中的起始偏移
五、工作流程示例
# 请求:input_ids = [1,2,3,4,5],无前缀缓存
req.prefix_indices = [] # 已缓存 0 个
req.fill_ids = [1,2,3,4,5] # 完整序列
req.extend_input_len = 5 # 需要 extend 5 个 token# 调度器分配 KV slots
out_cache_loc = [100, 101, 102, 103, 104]# 写入 ReqToTokenPool
req_to_token_pool[req.req_pool_idx, 0:5] = [100,101,102,103,104]# 构建 ForwardBatch
forward_batch = ForwardBatch(forward_mode=ForwardMode.EXTEND,input_ids=[1,2,3,4,5],positions=[0,1,2,3,4],extend_prefix_lens=[0],extend_seq_lens=[5],out_cache_loc=[100,101,102,103,104],...
)# ModelRunner 执行前向
logits = model_runner.forward(forward_batch) # 输出 5 个位置的 logits
next_token = sample(logits[-1]) # 通常只采样最后一个
六、相关模式
SGLang 还定义了其他与 extend 相关的模式:
ForwardMode.MIXED:batch 中 同时包含 extend 和 decode 请求(用于 overlap schedule)ForwardMode.DRAFT_EXTEND:在 EAGLE 草稿模型 中执行 extendForwardMode.TARGET_VERIFY:在 主模型 中验证草稿 token(本质是 extend 整个序列)
总结
| 术语 | 含义 |
|---|---|
extend | 对请求的 未缓存部分 进行首次前向计算(= prefill) |
| 目的 | 填充 KV Cache,计算 logits,为 decode 阶段做准备 |
| 特点 | 输入长度可变、计算量大、不支持 CUDA Graph |
| 触发时机 | 请求首次进入、chunked prefill 的每个 chunk |
理解 extend 模式,是掌握 LLM 推理中 prefill/decode 分离优化 的关键。
init_next_round_input 是 SGLang 中 Req(请求)类的一个关键方法,它的核心作用是:
为下一轮调度/前向计算准备输入状态,包括更新完整 token 序列(
fill_ids)、匹配前缀缓存(prefix_indices),并计算需要新处理的 token 数量(extend_input_len)。
一、函数签名与上下文
def init_next_round_input(self,tree_cache: Optional[BasePrefixCache] = None,enable_hierarchical_cache=False,
):
- 调用时机:每次请求即将进入 prefill 或 decode 阶段前,由调度器(
Scheduler)调用。 - 目的:让请求对象“知道”自己当前的状态,以便调度器决定如何分配内存、构建 batch。
二、逐行详解
1. 重置 placeholder embeddings(多模态支持)
self.input_placeholder_embs = self.original_input_placeholder_embs
- 如果请求包含多模态 placeholder(如
<image>),恢复原始嵌入,避免上一轮修改污染。
2. 构建完整 token 序列 fill_ids
self.fill_ids = self.origin_input_ids + self.output_ids
origin_input_ids:用户原始输入 token(如 prompt)output_ids:模型已生成的 tokenfill_ids= 完整上下文 =[prompt tokens] + [generated tokens]
✅ 这是当前请求的 完整逻辑序列,用于后续 KV Cache 匹配和 extend 计算。
3. 前缀缓存匹配(Radix/Chunk/HiRadix Cache)
if tree_cache is not None:if enable_hierarchical_cache:self.prefix_indices, self.last_node, self.last_node_global = (tree_cache.match_prefix(key=self.adjust_max_prefix_ids(), include_evicted=True))else:self.prefix_indices, self.last_node = tree_cache.match_prefix(rid=self.rid, key=self.adjust_max_prefix_ids())
关键点:
tree_cache.match_prefix(...):在前缀缓存中查找fill_ids的最长匹配前缀。- 返回:
prefix_indices:匹配到的 token 在 KV Cache 中的物理 slot 索引列表(如[100,101,102])last_node:缓存树中对应的节点(用于后续 evict/lock)
✅ 这一步实现了 KV Cache 前缀共享,避免重复 prefill。
4. 处理 Hierarchical Cache 的特殊情况
elif enable_hierarchical_cache:while self.last_node.evicted:# 如果 last_node 被驱逐,回退 prefix_indicesself.prefix_indices = self.prefix_indices[:-len(self.last_node.host_value)]self.last_node = self.last_node.parent
- 在
HiRadixCache中,部分节点可能被换出到 CPU。 - 如果
last_node已被 evict,则逐步回退,直到找到未 evict 的节点。
5. 计算需要新处理的 token 数(核心输出)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
extend_input_len= 需要新计算 KV Cache 的 token 数- 例如:
fill_ids = [1,2,3,4,5](总长 5)prefix_indices = [100,101](前 2 个已缓存)- →
extend_input_len = 3(需处理 token 3,4,5)
✅ 调度器将根据此值决定分配多少
out_cache_loc(KV Cache slot)。
三、辅助方法:adjust_max_prefix_ids()
def adjust_max_prefix_ids(self):self.fill_ids = self.origin_input_ids + self.output_idsinput_len = len(self.fill_ids)max_prefix_len = input_len - 1 # 至少留 1 个 token 用于生成 logitsif self.return_logprob:max_prefix_len = min(max_prefix_len, self.logprob_start_len)return self.fill_ids[:max_prefix_len]
- 作用:限制前缀匹配的最大长度,避免影响 logprob 计算或生成。
- 关键逻辑:
- 如果需要返回 logprob,前缀不能超过
logprob_start_len。 - 总是保留至少 1 个 token 用于当前 step 的 logits 计算。
- 如果需要返回 logprob,前缀不能超过
四、典型工作流程示例
场景:用户输入 "Hello",模型已生成 " world"
-
初始状态:
origin_input_ids = [100, 101](“Hello”)output_ids = [200, 201](" world")fill_ids = [100,101,200,201]
-
调用
init_next_round_input():tree_cache.match_prefix([100,101,200,201])→ 返回prefix_indices = [50,51,52](前 3 个已缓存)extend_input_len = 4 - 3 = 1
-
调度器行为:
- 分配 1 个新 KV slot(如
out_cache_loc = [53]) - 构建
input_ids = [201](最后一个 token 作为 decode 输入) - 写入
req_to_token_pool[req_idx, 3] = 53
- 分配 1 个新 KV slot(如
五、为什么这个函数重要?
- 实现前缀共享:大幅减少重复 prefill,提升吞吐。
- 支持 chunked prefill:大 prompt 分块处理时,每块都调用此函数。
- 支持 retract:OOM 时回退请求后,重新调用此函数恢复状态。
- logprob 兼容:确保 logprob 计算区域不被错误缓存。
总结
| 步骤 | 作用 |
|---|---|
构建 fill_ids | 获取完整上下文 |
匹配 prefix_indices | 查找已缓存的前缀 |
计算 extend_input_len | 确定需要新处理的 token 数 |
| 处理 hierarchical cache | 支持 CPU/GPU 混合缓存 |
init_next_round_input 是 SGLang 高效调度与缓存机制的核心枢纽,它让每个请求“自知状态”,从而实现灵活、高性能的批处理。
这里明确说是为了绕过logprob计算中的bug,所以暂时确保每个请求至少有一个token(不被缓存?)。也就是说,我们最多只能缓存到倒数第二个token,最后一个token不被缓存。这样,在扩展阶段,我们就会处理最后一个token(以及之后可能的新token)。但是,为什么这样做能绕过logprob计算的bug呢?我们不得而知。
总结
adjust_max_prefix_ids函数的作用是确定可以用于前缀缓存的最大token序列。它返回一个子序列,这个子序列是当前完整序列(输入+输出)的前缀,并且这个前缀的长度最多为input_len-1,即至少保留一个token不被缓存。
这样做的原因可能是由于logprob计算中的bug,暂时采用这种策略。在未来,这个限制可能会被放宽。
举例
假设当前请求的fill_ids是[1,2,3,4,5](长度为5)。
那么,adjust_max_prefix_ids返回的前缀可能是[1,2,3,4](长度为4),即最多缓存前4个token,第5个token不被缓存。
在接下来的扩展阶段,我们将处理第5个token(以及可能的新token)。
很好问题!我来详细解释adjust_max_prefix_ids的作用,特别是"确保至少1个token用于生成"的含义。
1. 函数完整代码
def adjust_max_prefix_ids(self):self.fill_ids = self.origin_input_ids + self.output_idsinput_len = len(self.fill_ids)# FIXME: To work around some bugs in logprob computation, we need to ensure each# request has at least one token. Later, we can relax this requirement and use `input_len`.max_prefix_len = input_len - 1if self.sampling_params.max_new_tokens > 0:# Need at least one token to compute logitsmax_prefix_len = min(max_prefix_len, input_len - 1)if self.return_logprob:max_prefix_len = min(max_prefix_len, self.logprob_start_len)max_prefix_len = max(max_prefix_len, 0)return self.fill_ids[:max_prefix_len]
2. "确保至少1个token用于生成"的含义
2.1 LLM生成的基本原理
在LLM中,生成下一个token需要:
- 输入: 前N个token
- 输出: 第N+1个token的概率分布
# 生成过程示例
输入: [1, 2, 3, 4] # "The cat sat on"
输出: [5] # "the" 的概率分布# 如果所有token都被缓存为前缀,就没有token用于生成了
2.2 具体例子说明
例子1:正常生成情况
# 当前序列
fill_ids = [1, 2, 3, 4, 5] # "The cat sat on the"
input_len = 5# 调整后的最大前缀长度
max_prefix_len = input_len - 1 # = 4# 这意味着:
# 前缀缓存最多可以缓存: [1, 2, 3, 4] # "The cat sat on"
# 必须保留: [5] # "the" 用于生成下一个token
例子2:如果缓存所有token会怎样
# 错误情况:缓存所有token
fill_ids = [1, 2, 3, 4, 5] # 完整序列
max_prefix_len = 5 # 错误!缓存了所有token# 问题:没有token用于计算下一个token的概率
# 输入序列: [] # 空,因为没有未缓存的token
# 无法生成下一个token!
3. 函数执行步骤详解
3.1 基础限制
max_prefix_len = input_len - 1 # 核心限制:保留最后1个token
为什么是input_len - 1?
- 序列有
input_len个token - 需要至少1个token作为生成下一个token的输入
- 所以最多只能缓存
input_len - 1个token
3.2 考虑生成需求
if self.sampling_params.max_new_tokens > 0:# Need at least one token to compute logitsmax_prefix_len = min(max_prefix_len, input_len - 1)
作用:如果还需要生成新token,必须确保有token用于计算logits。
3.3 考虑logprob计算
if self.return_logprob:max_prefix_len = min(max_prefix_len, self.logprob_start_len)
logprob计算的特殊要求:
- logprob计算需要知道每个token的前一个token
logprob_start_len指定从哪个位置开始计算logprob- 不能缓存超过这个位置的token
3.4 边界保护
max_prefix_len = max(max_prefix_len, 0) # 确保非负
4. 具体场景分析
场景1:新请求的预填充
# 新请求,没有输出
req.origin_input_ids = [1, 2, 3, 4] # "What is AI"
req.output_ids = [] # 无输出
req.fill_ids = [1, 2, 3, 4]
input_len = 4# adjust_max_prefix_ids 过程:
max_prefix_len = 4 - 1 = 3 # 保留最后1个token
# 返回: [1, 2, 3] # 最多缓存前3个token
为什么这样设计?
- 缓存前3个token的KV值:[1, 2, 3] → “What is AI”
- 使用第4个token “AI” 生成第一个输出token
场景2:生成过程中的调整
# 生成过程中
req.origin_input_ids = [1, 2, 3] # "Hello world"
req.output_ids = [4, 5] # "How are"
req.fill_ids = [1, 2, 3, 4, 5]
input_len = 5# adjust_max_prefix_ids 过程:
max_prefix_len = 5 - 1 = 4 # 保留最后1个token
# 返回: [1, 2, 3, 4] # 最多缓存前4个token
作用:
- 缓存:[1, 2, 3, 4] → “Hello world How are”
- 使用:[5] → “are” 生成下一个token “you”
场景3:logprob计算的影响
# 需要计算logprob的情况
req.origin_input_ids = [1, 2, 3, 4, 5, 6] # 长序列
req.output_ids = [7, 8]
req.fill_ids = [1, 2, 3, 4, 5, 6, 7, 8]
input_len = 8
req.return_logprob = True
req.logprob_start_len = 5 # 从第5个token开始计算logprob# adjust_max_prefix_ids 过程:
max_prefix_len = 8 - 1 = 7 # 基础限制
max_prefix_len = min(7, 5) = 5 # logprob限制更严格
# 返回: [1, 2, 3, 4, 5] # 最多缓存到logprob_start_len
为什么logprob需要这个限制?
# logprob计算示例
序列: [1, 2, 3, 4, 5, 6, 7, 8]
logprob_start_len = 5# 需要计算token 6,7,8的logprob
# 但计算token6的logprob需要token5的隐藏状态
# 如果token5被缓存了,就无法计算token6的logprob
5. 技术原理深度解析
5.1 Transformer的生成机制
# Transformer生成下一个token的过程
def generate_next_token(sequence):# 输入: 整个序列的token [t1, t2, ..., tn]# 输出: 下一个token tn+1的概率分布# 1. 计算序列的隐藏状态hidden_states = transformer(sequence) # 形状: [n, hidden_size]# 2. 只使用最后一个token的隐藏状态预测下一个tokenlast_hidden = hidden_states[-1] # 形状: [hidden_size]next_token_logits = lm_head(last_hidden) # 形状: [vocab_size]return next_token_logits
关键点:生成只需要最后一个token的隐藏状态,但计算这个隐藏状态需要所有前序token的KV缓存。
5.2 前缀缓存与生成的协调
# 有前缀缓存时的生成过程
def generate_with_prefix_cache(sequence, prefix_indices):# sequence: 完整序列 [1, 2, 3, 4, 5]# prefix_indices: 已缓存的KV位置 [100, 101, 102] (对应token 1,2,3)# 1. 从缓存中读取前3个token的KVk_cache = get_k_cache(prefix_indices) # 形状: [3, heads, dim]v_cache = get_v_cache(prefix_indices)# 2. 只计算后2个token的KV并生成new_tokens = sequence[3:] # [4, 5]new_k, new_v = compute_kv(new_tokens)# 3. 合并KV缓存进行注意力计算full_k = concat(k_cache, new_k)full_v = concat(v_cache, new_v)# 4. 使用最后一个token生成last_hidden = attention_with_kv_cache(full_k, full_v, sequence)next_token_logits = lm_head(last_hidden[-1])return next_token_logits
6. 实际应用中的影响
6.1 对性能的影响
# 序列长度: 1000个token
# 情况A: 缓存999个token,计算1个新token
计算量 ≈ O(1²) # 主要计算1个新token# 情况B: 缓存0个token,计算1000个token
计算量 ≈ O(1000²) # 计算整个序列# 性能差异: 1000000倍!
6.2 对内存的影响
# KV缓存内存占用
缓存999个token: 需要999个KV位置
缓存0个token: 需要0个KV位置(但计算代价高)# 权衡: 用内存换计算速度
7. 特殊情况处理
7.1 极短序列
# 只有1个token的情况
req.fill_ids = [1] # 单个token
input_len = 1# adjust_max_prefix_ids:
max_prefix_len = 1 - 1 = 0 # 不能缓存任何token
# 返回: [] # 空列表
处理:必须计算整个序列,因为没有token可用于生成。
7.2 边界条件保护
max_prefix_len = max(max_prefix_len, 0) # 确保非负
作用:防止在空序列或极短序列时出现负数。
8. 总结
adjust_max_prefix_ids的核心作用是:
- 生成保障:确保至少1个token可用于生成下一个token
- 性能优化:最大化前缀缓存的使用,减少计算量
- 功能协调:平衡生成需求与logprob计算需求
- 边界保护:处理各种极端序列长度情况
关键理解:在LLM推理中,我们不能把所有token都缓存为前缀,必须保留至少一个token作为"种子"来生成后续内容。这个函数就是智能地确定可以安全缓存的最大前缀长度,在性能和功能之间找到最佳平衡点。
在你提供的 SGLang 代码中,并没有一个名为 alloc_token_slot 的独立函数,但有一个非常关键且功能完全对应的方法:
ScheduleBatch.alloc_token_slots(num_tokens: int, backup_state: bool = False)
这是你实际想了解的函数。下面我将详细讲解它的 作用、实现逻辑、调用时机和使用方式。
一、函数签名与目的
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
功能:
为
num_tokens个新 token 分配 KV Cache 中的物理存储位置(slot indices)。
- 返回一个
torch.Tensor,shape[num_tokens],每个元素是全局 token slot ID(如[100, 101, 102])。 - 如果内存不足,会尝试从 前缀缓存(Radix/Chunk Cache) 中 驱逐(evict) 可回收的请求以释放空间。
- 如果仍不足,抛出
RuntimeError(OOM)。
二、内部实现逻辑(逐行解析)
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):# 1. 检查内存是否足够;若不足,尝试从 tree_cache 驱逐if self.token_to_kv_pool_allocator.available_size() < num_tokens:if self.tree_cache is not None:self.tree_cache.evict(num_tokens) # 释放至少 num_tokens 个 slots# 2. (可选)备份当前空闲列表状态,用于失败回滚if backup_state:state = self.token_to_kv_pool_allocator.backup_state()# 3. 执行实际分配out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)# 4. 处理分配失败if out_cache_loc is None:phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"error_msg = f"{phase_str} out of memory. ..."logger.error(error_msg)raise RuntimeError(error_msg)# 5. 返回结果if backup_state:return out_cache_loc, stateelse:return out_cache_loc
三、核心组件协作
| 组件 | 作用 |
|---|---|
TokenToKVPoolAllocator | 管理空闲 slot 列表(free_slots),提供 alloc() / free() 接口 |
tree_cache(Radix/Chunk/HiRadix Cache) | 前缀缓存;当内存不足时,evict(N) 会释放至少 N 个 token 的 KV Cache |
out_cache_loc | 分配结果:token → KV Cache slot 的映射 |
✅ 设计亮点:通过
evict实现 弹性内存管理,避免直接拒绝新请求。
四、调用时机与使用场景
1. Prefill 阶段(prepare_for_extend)
# 计算需要 extend 的总 token 数
extend_num_tokens = sum(len(ids) for ids in input_ids)# 分配 slots
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
- 用于为 新请求的 prompt 或 chunked prefill 的 chunk 分配 KV Cache。
2. Decode 阶段(prepare_for_decode)
# 每个请求生成 1 个新 token,共 bs 个
self.out_cache_loc = self.alloc_token_slots(bs)
- 用于为 自回归生成 分配新 token 的 KV Cache。
3. Retract 场景(内存不足时回退请求)
- 虽然
retract_decode主要调用free(),但后续重新调度时会再次调用alloc_token_slots。
五、返回值 out_cache_loc 的用途
分配得到的 out_cache_loc 会被用于:
-
写入
ReqToTokenPool(建立请求 ↔ token 位置映射):req_to_token_pool.write((req_idx, pos_slice), out_cache_loc[...]) -
构建
ModelWorkerBatch→ForwardBatch:ForwardBatch(out_cache_loc=out_cache_loc, ...) -
模型前向时写入 KV Cache:
kv_cache.set_kv_buffer(loc=out_cache_loc, cache_k=..., cache_v=...)
✅ 它是 调度器与模型执行器之间的“内存契约”。
六、与分页(Paged KV Cache)的关系
- 当
page_size == 1(默认):使用alloc_token_slots - 当
page_size > 1:使用alloc_paged_token_slots_extend/alloc_paged_token_slots_decode- 因为分页模式下需处理 非连续 pages、续页 等复杂逻辑
七、示例
假设:
- 请求 A 需要 extend 3 个 token
- 请求 B 需要 extend 2 个 token
num_tokens = 5
调用:
out_cache_loc = batch.alloc_token_slots(5)
# 假设返回: tensor([100, 101, 102, 103, 104])
后续:
- 请求 A 的 token → slots
[100, 101, 102] - 请求 B 的 token → slots
[103, 104]
八、注意事项
- 仅用于
page_size == 1 - OOM 时会自动 evict,但可能影响缓存命中率
backup_state=True用于安全分配(失败时回滚)- 返回的是 GPU tensor(
.to(device, non_blocking=True)在上层完成)
总结
| 特性 | 说明 |
|---|---|
| 作用 | 为新 token 分配 KV Cache 位置 |
| 调用方 | ScheduleBatch.prepare_for_extend/decode |
| 依赖 | TokenToKVPoolAllocator + RadixCache(evict) |
| 关键机制 | 内存不足时 evict、失败回滚 |
| 输出 | out_cache_loc(GPU tensor of slot indices) |
| 适用模式 | 仅 page_size == 1(非分页) |
alloc_token_slots 是 SGLang 实现 高效、安全、弹性 KV Cache 管理 的核心环节。
mix_with_running 这个函数干啥的,举例说明
def mix_with_running(self, running_batch: “ScheduleBatch”):
self.forward_mode = ForwardMode.MIXED
running_bs = running_batch.batch_size()
for req in running_batch.reqs:req.fill_ids = req.origin_input_ids + req.output_idsreq.extend_input_len = 1input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend([len(r.origin_input_ids) + len(r.output_ids) + deltafor r in running_batch.reqs]
)
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)
函数作用
mix_with_running函数的作用是将当前批次(通常是一个预填充批次)与一个正在运行的解码批次合并,形成一个混合批次。这种混合批次可以同时处理新请求的预填充和已运行请求的解码。
使用场景
在流式推理服务中,经常会有新的请求到来(需要预填充)和已有的请求在生成(解码步骤)。为了充分利用GPU,我们可以将预填充和解码一起在一个批次中处理,这就是混合批次。
详细步骤
设置前向模式为混合模式:self.forward_mode = ForwardMode.MIXED
更新运行中批次的每个请求:
将每个请求的fill_ids更新为原始输入ID加上输出ID(即完整的当前序列)
设置每个请求的extend_input_len为1(因为解码阶段每次只生成一个token)
合并输入ID和输出缓存位置:
将当前批次的input_ids和运行中批次的input_ids拼接
将当前批次的out_cache_loc和运行中批次的out_cache_loc拼接
合并两个批次:调用merge_batch方法将运行中批次的请求合并到当前批次
更新前缀长度、扩展长度等元数据:
对于运行中批次的每个请求,计算其前缀长度(即已经缓存了的token数量)。注意这里有一个delta,在重叠调度模式下为0,非重叠为-1。
扩展长度(extend_lens)每个运行中请求都是1(因为每个请求只生成一个token)
扩展token数量(extend_num_tokens)增加运行中批次的大小
扩展logprob起始长度(extend_logprob_start_lens)对于运行中批次的每个请求设置为0
举例说明
假设我们有两个批次:
当前批次(预填充批次):包含2个新请求
请求A: 输入长度为10,需要预填充10个token
请求B: 输入长度为15,需要预填充15个token
运行中批次(解码批次):包含3个正在生成的请求
请求C: 已经生成了5个token,原始输入长度为8,当前序列长度为8+5=13
请求D: 已经生成了10个token,原始输入长度为20,当前序列长度为30
请求E: 已经生成了3个token,原始输入长度为12,当前序列长度为15
合并过程
设置模式:当前批次变为混合模式。
更新运行中批次的请求:
对于请求C、D、E,设置fill_ids为完整的当前序列(原始输入+输出),并设置extend_input_len=1。
合并张量:
当前批次的input_ids(长度为10+15=25)与运行中批次的input_ids(长度为3,因为每个请求一个token,即上次生成的token)拼接,得到长度为28的input_ids。
类似地,out_cache_loc也会合并。
合并请求:将请求C、D、E加入到当前批次的请求列表中,现在当前批次有5个请求。
更新元数据:
前缀长度(prefix_lens):
对于预填充的请求A和B,已经在prepare_for_extend中设置好了(比如可能利用了前缀缓存,假设请求A的前缀长度为0,请求B的前缀长度为0,即全部需要计算)
现在加入运行中批次的前缀长度:对于每个运行中请求,前缀长度是len(origin_input_ids) + len(output_ids) + delta。注意,这里的前缀长度是指已经缓存的token数量。在非重叠模式下(delta=-1),因为最后一个token还没有缓存,所以前缀长度是当前序列长度减1;在重叠模式下,可能是当前序列长度(因为重叠调度可能已经缓存了最后一个token?)。代码中的注释说明“我们不缓存每个解码步骤”,所以这里可能需要仔细设计。实际上,在解码阶段,每次生成一个token,我们会缓存这个新生成的token,但这里的前缀长度指的是在本次扩展之前已经缓存的数量。所以对于运行中请求,在本次扩展(生成下一个token)之前,已经缓存了len(origin_input_ids) + len(output_ids) - 1(非重叠)或len(origin_input_ids) + len(output_ids)(重叠)?这个根据调度策略有所不同。
扩展长度(extend_lens):预填充的请求A和B分别是10和15,运行中批次每个请求都是1,所以扩展长度列表变为[10, 15, 1, 1, 1]。
扩展token总数(extend_num_tokens)变为25(预填充)+3(解码)=28。
扩展logprob起始长度(extend_logprob_start_lens)对于运行中批次请求设置为0。
注意点
这个函数用于混合预填充和解码,以提高GPU利用率。
重叠调度和非重叠调度在计算前缀长度时有差异,通过delta调整。
合并后,模型前向传播会同时处理预填充token(多个)和解码token(每个请求1个)。
总结
mix_with_running实现了预填充和解码的混合批次处理,是提高推理效率的关键技术之一。它允许推理引擎同时处理新请求的初始化和老请求的继续生成,从而充分利用GPU计算资源。
我来详细解析mix_with_running这个函数,它实现了混合批次处理的重要功能。
1. 函数完整代码
def mix_with_running(self, running_batch: "ScheduleBatch"):self.forward_mode = ForwardMode.MIXEDrunning_bs = running_batch.batch_size()for req in running_batch.reqs:req.fill_ids = req.origin_input_ids + req.output_idsreq.extend_input_len = 1input_ids = torch.cat([self.input_ids, running_batch.input_ids])out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])self.merge_batch(running_batch)self.input_ids = input_idsself.out_cache_loc = out_cache_loc# For overlap scheduler, the output_ids has one step delaydelta = 0 if self.enable_overlap else -1# NOTE: prefix_indices is what has been cached, but we don't cache each decode stepself.prefix_lens.extend([len(r.origin_input_ids) + len(r.output_ids) + deltafor r in running_batch.reqs])self.extend_lens.extend([1] * running_bs)self.extend_num_tokens += running_bs# TODO (lianmin): Revisit this. It should be seq_len - 1self.extend_logprob_start_lens.extend([0] * running_bs)
2. 函数核心作用
主要功能:将预填充批次与正在运行的解码批次合并,形成一个混合批次,在一次前向传播中同时处理新请求的预填充和已有请求的解码。
3. 使用场景和动机
3.1 为什么要混合批次?
在LLM推理服务中,通常有两种类型的请求:
- 新请求:需要预填充(Prefill)处理整个输入序列
- 运行中请求:需要解码(Decode)生成下一个token
传统方式:分别处理,效率低
时间线:
[预填充批次] → [空闲] → [解码批次] → [空闲] → [预填充批次] ...
混合方式:同时处理,提高GPU利用率
时间线:
[混合批次:预填充 + 解码] → [混合批次:预填充 + 解码] → ...
3.2 性能优势
# 分开处理的GPU利用率
预填充: 100% GPU利用率 → 解码: 100% GPU利用率 → 空闲: 0% GPU利用率# 混合处理的GPU利用率
混合: 100% GPU利用率 → 混合: 100% GPU利用率 → 混合: 100% GPU利用率
4. 详细执行步骤
4.1 设置前向模式
self.forward_mode = ForwardMode.MIXED
作用:标记当前批次为混合模式,模型需要同时处理预填充和解码。
4.2 准备运行中批次的请求
running_bs = running_batch.batch_size()for req in running_batch.reqs:req.fill_ids = req.origin_input_ids + req.output_idsreq.extend_input_len = 1
关键操作:
- 重建完整序列:
fill_ids = origin_input_ids + output_ids - 设置扩展长度:
extend_input_len = 1(解码每次只生成1个token)
示例:
# 运行中请求的状态
req.origin_input_ids = [1, 2, 3] # "Hello world"
req.output_ids = [4, 5, 6] # "How are you"
# 处理后:
req.fill_ids = [1, 2, 3, 4, 5, 6] # 完整序列
req.extend_input_len = 1 # 每次解码只处理1个新token
4.3 合并输入数据
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
数据合并:
input_ids:预填充token + 解码tokenout_cache_loc:预填充KV位置 + 解码KV位置
4.4 合并批次元数据
self.merge_batch(running_batch)
调用之前分析过的merge_batch方法,合并所有请求和采样信息。
4.5 计算前缀长度(关键逻辑)
# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend([len(r.origin_input_ids) + len(r.output_ids) + deltafor r in running_batch.reqs]
)
前缀长度计算逻辑:
- 重叠调度(
delta=0):前缀包含所有已生成token - 非重叠调度(
delta=-1):前缀不包含最后一个token(用于生成)
具体示例:
# 运行中请求
req.origin_input_ids = [1, 2, 3] # 3个token
req.output_ids = [4, 5, 6] # 3个已生成token
当前序列长度 = 6# 非重叠调度 (delta = -1)
prefix_len = 3 + 3 - 1 = 5 # 缓存前5个token,第6个用于生成# 重叠调度 (delta = 0)
prefix_len = 3 + 3 + 0 = 6 # 缓存所有6个token
4.6 设置扩展参数
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
self.extend_logprob_start_lens.extend([0] * running_bs)
参数含义:
extend_lens:运行中请求每个扩展1个tokenextend_num_tokens:总扩展token数增加extend_logprob_start_lens:运行中请求不计算输入logprob
5. 具体示例分析
5.1 场景设置
假设我们有:
- 预填充批次:2个新请求
- 运行中批次:3个解码请求
预填充批次状态
# 请求A(新)
reqA.origin_input_ids = [10, 11, 12, 13] # "What is AI"
reqA.output_ids = [] # 无输出
reqA.fill_ids = [10, 11, 12, 13]
reqA.extend_input_len = 4# 请求B(新)
reqB.origin_input_ids = [20, 21, 22] # "Hello world"
reqB.output_ids = [] # 无输出
reqB.fill_ids = [20, 21, 22]
reqB.extend_input_len = 3# 批次数据
self.input_ids = [10, 11, 12, 13, 20, 21, 22] # 7个token
self.out_cache_loc = [100, 101, 102, 103, 104, 105, 106] # 7个位置
运行中批次状态
# 请求C(运行中)
reqC.origin_input_ids = [30, 31] # "The weather"
reqC.output_ids = [32, 33] # "is nice"
reqC.fill_ids = [30, 31, 32, 33] # 需要更新
reqC.extend_input_len = 1 # 需要设置# 请求D(运行中)
reqD.origin_input_ids = [40, 41, 42] # "Machine learning"
reqD.output_ids = [43] # "is"
reqD.fill_ids = [40, 41, 42, 43] # 需要更新
reqD.extend_input_len = 1 # 需要设置# 请求E(运行中)
reqE.origin_input_ids = [50] # "AI"
reqE.output_ids = [51, 52, 53] # "will change"
reqE.fill_ids = [50, 51, 52, 53] # 需要更新
reqE.extend_input_len = 1 # 需要设置# 批次数据
running_batch.input_ids = [33, 43, 53] # 每个请求的上次输出token
running_batch.out_cache_loc = [200, 201, 202] # 为这次解码分配的位置
5.2 混合过程执行
步骤1:更新运行中请求
for req in running_batch.reqs:req.fill_ids = req.origin_input_ids + req.output_idsreq.extend_input_len = 1
更新后:
请求C: fill_ids = [30, 31, 32, 33], extend_input_len = 1
请求D: fill_ids = [40, 41, 42, 43], extend_input_len = 1
请求E: fill_ids = [50, 51, 52, 53], extend_input_len = 1
步骤2:合并张量数据
input_ids = torch.cat([[10,11,12,13,20,21,22], [33,43,53]])
# 结果: [10,11,12,13,20,21,22,33,43,53] # 10个tokenout_cache_loc = torch.cat([[100,101,102,103,104,105,106], [200,201,202]])
# 结果: [100,101,102,103,104,105,106,200,201,202] # 10个位置
步骤3:合并批次
self.merge_batch(running_batch)
# 现在 self.reqs = [reqA, reqB, reqC, reqD, reqE] # 5个请求
步骤4:计算前缀长度(假设非重叠调度)
delta = -1 # 非重叠调度# 预填充批次的前缀长度(在prepare_for_extend中已设置)
# 假设: self.prefix_lens = [0, 0] # 请求A和B没有前缀缓存# 添加运行中请求的前缀长度
self.prefix_lens.extend([len(reqC.origin_input_ids) + len(reqC.output_ids) + (-1), # 2+2-1=3len(reqD.origin_input_ids) + len(reqD.output_ids) + (-1), # 3+1-1=3 len(reqE.origin_input_ids) + len(reqE.output_ids) + (-1) # 1+3-1=3
])
# 结果: prefix_lens = [0, 0, 3, 3, 3]
步骤5:设置扩展参数
self.extend_lens.extend([1, 1, 1]) # 原来[4,3] + [1,1,1] = [4,3,1,1,1]
self.extend_num_tokens += 3 # 原来7 + 3 = 10
self.extend_logprob_start_lens.extend([0, 0, 0]) # 运行中请求不计算输入logprob
5.3 最终混合批次状态
# 混合批次最终状态
forward_mode = ForwardMode.MIXED
reqs = [reqA, reqB, reqC, reqD, reqE] # 5个请求# 输入数据
input_ids = [10,11,12,13,20,21,22,33,43,53] # 10个token
out_cache_loc = [100,101,102,103,104,105,106,200,201,202] # 10个位置# 扩展参数
prefix_lens = [0, 0, 3, 3, 3] # 每个请求的前缀长度
extend_lens = [4, 3, 1, 1, 1] # 每个请求需要扩展的token数
extend_num_tokens = 10 # 总扩展token数
6. 模型前向传播时的处理
6.1 混合模式下的注意力计算
在混合模式下,模型需要同时处理:
- 预填充请求:计算多个token的注意力(全序列)
- 解码请求:计算单个token的注意力(增量)
def mixed_attention_forward(input_ids, prefix_lens, extend_lens):# 输入: [10,11,12,13,20,21,22,33,43,53]# prefix_lens: [0,0,3,3,3]# extend_lens: [4,3,1,1,1]# 请求A: 处理token [10,11,12,13] (4个token,无前缀)# 请求B: 处理token [20,21,22] (3个token,无前缀) # 请求C: 处理token [33] (1个token,前缀3个)# 请求D: 处理token [43] (1个token,前缀3个)# 请求E: 处理token [53] (1个token,前缀3个)# 实际实现使用 RaggedTensor 或类似结构处理不规则序列
6.2 性能优势体现
# 分开处理的GPU使用
预填充: 处理7个token → GPU利用率: 70%
解码: 处理3个token → GPU利用率: 30%# 混合处理的GPU使用
混合: 处理10个token → GPU利用率: 100%
7. 调度策略的影响
7.1 重叠调度 vs 非重叠调度
delta = 0 if self.enable_overlap else -1
重叠调度(delta=0):
- 前缀包含最后一个token
- 解码步骤可以提前准备
- 更高的吞吐量,但可能增加延迟
非重叠调度(delta=-1):
- 前缀不包含最后一个token
- 更简单的实现
- 较低的吞吐量,但延迟更可预测
7.2 实际调度示例
# 时间线示例
时间点 t0: [预填充批次] 开始
时间点 t1: [运行中批次] 准备好
时间点 t2: mix_with_running() 创建混合批次
时间点 t3: [混合批次] 执行前向传播# 重叠调度可以在t1就开始准备运行中批次
# 非重叠调度需要等待t0完成
8. 总结
mix_with_running函数的核心价值:
- 提高GPU利用率:通过合并预填充和解码请求,减少GPU空闲时间
- 降低延迟:新请求不用等待当前解码完成
- 提高吞吐量:一次性处理更多样化的计算负载
- 灵活调度:支持重叠和非重叠两种调度策略
关键理解:混合批次处理是现代LLM推理服务中的关键技术,它通过将计算密集的预填充和内存带宽受限的解码组合在一起,实现了更好的硬件资源利用,显著提升了整体推理性能。
