【vLLM 学习】Multilora Inference
vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。
更多 vLLM 中文文档及教程可访问 →vllm.hyper.ai/
*在线运行 vLLM 入门教程:零基础分步指南
源码 examples/offline_inference/multilora_inference.py
# SPDX-License-Identifier: Apache-2.0"""
此示例显示了如何使用多路线功能
用于离线推理。
需要 HuggingFace 凭证才能访问 Llama2。
"""from typing import Optionalfrom huggingface_hub import snapshot_downloadfrom vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequestdef create_test_prompts(lora_path: str
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:"""创建包含采样参数的测试提示列表。为基准模型创建 2 个请求,为 LoRA 创建 4 个请求。我们定义了两个不同的 LoRA 适配器(出于演示目的使用相同模型)。由于我们同时设置了 max_loras=1,预计使用第二个 LoRA 适配器的请求将在所有使用第一个适配器的请求完成后运行。"""return [("A robot may not injure a human being",SamplingParams(temperature=0.0,logprobs=1,prompt_logprobs=1,max_tokens=128), None),("To be or not to be,",SamplingParams(temperature=0.8,top_k=5,presence_penalty=0.2,max_tokens=128), None),("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501SamplingParams(temperature=0.0,logprobs=1,prompt_logprobs=1,max_tokens=128,stop_token_ids=[32003]),LoRARequest("sql-lora", 1, lora_path)),("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501SamplingParams(temperature=0.0,logprobs=1,prompt_logprobs=1,max_tokens=128,stop_token_ids=[32003]),LoRARequest("sql-lora2", 2, lora_path)),]def process_requests(engine: LLMEngine,test_prompts: list[tuple[str, SamplingParams,Optional[LoRARequest]]]):"持续处理提示列表并处理输出"request_id = 0while test_prompts or engine.has_unfinished_requests():if test_prompts:prompt, sampling_params, lora_request = test_prompts.pop(0)engine.add_request(str(request_id),prompt,sampling_params,lora_request=lora_request)request_id += 1request_outputs: list[RequestOutput] = engine.step()for request_output in request_outputs:if request_output.finished:print(request_output)def initialize_engine() -> LLMEngine:"""Initialize the LLMEngine."""# max_loras:控制可以在同一批中使用的 LoRA 的数量。# 较大的值将导致更高的内存使用情况# 因为每个 LoRA 插槽需要其自己的前置张量。# max_lora_rank:控制所有 LoRA 的最大支持 rank 。# 更大的值将导致更高的内存使用。如果您知道所有 LoRA 都会# 使用相同的 rank ,建议将其设置为尽可能低。# max_cpu_loras:控制 CPU LORA 缓存的大小。engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",enable_lora=True,max_loras=1,max_lora_rank=8,max_cpu_loras=2,max_num_seqs=256)return LLMEngine.from_engine_args(engine_args)def main():"""Main function that sets up and runs the prompt processing."""engine = initialize_engine()lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")test_prompts = create_test_prompts(lora_path)process_requests(engine, test_prompts)if __name__ == '__main__':main()