How Large Language Model Q&A Works
This post builds an OpenAI-format API service around a local large language model and walks through what happens during a call.
一、User Input
The OpenAI-style request format is:
messages = [ {"role": "system", "content": "你是一个乐于助人的 AI 助手。"}, {"role": "user", "content": "你好!请介绍一下你自己。"}]
Here role identifies the speaker in the prompt (typically system, user, or assistant), and content is that role's question or answer.
二、Server-Side Processing
1、Use the tokenizer to convert the user input into the chat template format
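For reference, a client sends these messages to an OpenAI-compatible endpoint as follows. This is a minimal sketch using the official openai Python SDK (assumed to be installed); base_url points at the local service built in section 三 below, and api_key is a placeholder because that server does not check it:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
messages = [
    {"role": "system", "content": "你是一个乐于助人的 AI 助手。"},
    {"role": "user", "content": "你好!请介绍一下你自己。"}
]
# The SDK POSTs the messages to /v1/chat/completions and parses the JSON response
resp = client.chat.completions.create(model="qwen3-4b", messages=messages)
print(resp.choices[0].message.content)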
<|im_start|>system
你是一个乐于助人的 AI 助手。<|im_end|>
<|im_start|>user
你好!请介绍一下你自己。<|im_end|>
<|im_start|>assistant
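A minimal sketch of the call that produces the template above, assuming the Qwen3-4B tokenizer is available locally (this mirrors apply_chat_template in server.py below):

from transformers import AutoTokenizer

MODEL_PATH = r"D:\modelbase\Qwen3-4B"  # local model path, as in server.py
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
messages = [{"role": "system", "content": "你是一个乐于助人的 AI 助手。"},
            {"role": "user", "content": "你好!请介绍一下你自己。"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,               # return the templated string rather than token IDs
    add_generation_prompt=True,   # append the trailing <|im_start|>assistant turn
)
print(prompt)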
2、Use the tokenizer to convert the templated text into a token ID sequence
{'input_ids': tensor([[151644, 8948, 198, 56568, 101909, 99350, 34204, 99262, 103947, 15235, 54599, 102, 44934, 1773, 151645, 198, 151644, 872, 198, 108386, 6313, 14880, 109432, 107828, 1773, 151645, 198, 151644, 77091, 198]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
Here input_ids is the token ID sequence fed to the model, and attention_mask marks which tokens are real content (to be attended to) and which are padding.
3、Model output
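A minimal sketch of the tokenizer call that produces this dict (mirroring server.py; prompt is the templated string from step 1 and DEVICE is "cuda" when a GPU is available):

inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
print(inputs)  # {'input_ids': tensor([[...]]), 'attention_mask': tensor([[...]])}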
tensor([[151644, 8948, 198, 56568, 101909, 99350, 34204, 99262, 103947,15235, 54599, 102, 44934, 1773, 151645, 198, 151644, 872,198, 108386, 6313, 14880, 109432, 107828, 1773, 151645, 198,151644, 77091, 198, 151667, 198, 99692, 3837, 20002, 104029,109432, 99283, 1773, 101140, 3837, 35946, 85106, 100692, 20002,104378, 1773, 99650, 87267, 32664, 15469, 110498, 102342, 99794,3837, 100631, 99172, 81167, 97611, 98380, 33108, 105795, 1773,100622, 99350, 34204, 99262, 103947, 15469, 110498, 3837, 35946,99730, 11622, 106098, 100136, 104715, 110098, 36407, 104493, 3407,104326, 3837, 104515, 103944, 102104, 102994, 99936, 27442, 5122,97611, 101294, 5373, 98380, 5373, 116541, 101034, 100007, 100364,20002, 1773, 85106, 101153, 37029, 102767, 99361, 105302, 116925,3837, 100662, 113113, 32108, 3837, 115987, 100047, 101128, 1773,91572, 3837, 30534, 102017, 97611, 100772, 3837, 101912, 42140,102064, 100143, 5373, 100134, 99788, 5373, 104913, 113272, 33108,113065, 3837, 99654, 20002, 26232, 101222, 97611, 100661, 3407,101948, 3837, 20002, 87267, 99880, 99392, 35946, 64471, 100006,54542, 100646, 88802, 3837, 101912, 102104, 86119, 5373, 104223,43815, 5373, 99553, 101898, 49567, 1773, 101886, 96050, 100157,15946, 85106, 104496, 100001, 99522, 3837, 101987, 97611, 110523,33071, 1773, 91572, 3837, 30534, 104125, 20002, 101080, 100398,86119, 3837, 105920, 35946, 105344, 99553, 100364, 3407, 104019,60533, 100166, 104542, 3837, 17177, 27442, 66394, 87267, 33126,110044, 1773, 77288, 20002, 87267, 99880, 102104, 110485, 3837,99999, 85106, 102243, 27369, 32757, 33108, 86744, 57553, 33071,1773, 100161, 3837, 23031, 97815, 99692, 80565, 72881, 50009,101143, 3837, 115987, 104048, 109010, 3837, 102167, 100642, 101069,8997, 151668, 271, 108386, 6313, 104198, 48, 16948, 3837,46944, 67071, 31935, 64559, 104800, 100048, 9370, 101951, 102064,104949, 1773, 35946, 100006, 100364, 56568, 102104, 100646, 86119,5373, 104223, 43815, 5373, 99553, 101898, 33108, 71817, 104913,113272, 1773, 97611, 100772, 100630, 48443, 16, 13, 3070,42140, 102064, 100143, 334, 5122, 35946, 100006, 101128, 62926,43959, 101312, 102064, 104597, 3837, 100630, 104811, 5373, 105205,5373, 24339, 72881, 5373, 105576, 72881, 49567, 8997, 17,13, 3070, 100134, 99788, 334, 5122, 35946, 100006, 67338,100722, 20074, 100134, 3837, 99607, 103983, 100005, 100032, 33108,101139, 8997, 18, 13, 3070, 104913, 113272, 334, 5122,35946, 100006, 71817, 106888, 104913, 113272, 33108, 104552, 100768,8997, 19, 13, 3070, 113065, 334, 5122, 35946, 100006,43959, 102343, 43815, 3837, 29524, 101108, 5373, 107604, 5373,109620, 49567, 8997, 20, 13, 3070, 115447, 334, 5122,35946, 100006, 100364, 56568, 100638, 99912, 86119, 3837, 29524,110569, 5373, 111540, 5373, 99424, 101898, 49567, 3407, 100783,56568, 104139, 86119, 57191, 85106, 100364, 3837, 104570, 102422,106525, 6313, 105351, 110121, 99553, 115404, 105427, 33108, 100364,1773, 144236, 151645]], device='cuda:0')
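The tensor above is the raw return value of model.generate. A sketch of the call that produces it, using the same generation parameters server.py applies by default (the temperature, top_p, and max_new_tokens values here are illustrative request values):

with torch.no_grad():
    outputs = model.generate(
        **inputs,                           # input_ids and attention_mask from step 2
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )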
The output also contains the input token ID sequence, for example
151644, 8948, 198, 56568, 101909, 99350, 34204, 99262, 103947,15235, 54599, 102, 44934, 1773, 151645, 198, 151644, 872,198, 108386, 6313, 14880, 109432, 107828, 1773, 151645, 198,151644, 77091, 198
so only the newly generated token IDs need to be extracted and decoded with the tokenizer:
response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
4、The model service wraps the result and returns it
return JSONResponse({
    "id": "chatcmpl-" + str(int(time.time())),
    "object": "chat.completion",
    "created": int(time.time()),
    "model": request.model,
    "choices": [{
        "message": {"role": "assistant", "content": response_text},
        "index": 0,
        "finish_reason": "stop"
    }],
    "usage": {
        "prompt_tokens": inputs.input_ids.shape[1],
        "completion_tokens": len(tokenizer.encode(response_text)),
        "total_tokens": inputs.input_ids.shape[1] + len(tokenizer.encode(response_text))
    }
})
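One caveat: completion_tokens above is computed by re-encoding the decoded text, which can differ slightly from the number of tokens the model actually generated, since special tokens are stripped before decoding. If exact counts matter, they could instead be taken from the tensors directly; a one-line sketch of that alternative (not what the code above does):

completion_tokens = outputs.shape[1] - inputs.input_ids.shape[1]  # newly generated tokens, special tokens included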
5、Streaming output
if request.stream:
    # Streaming response
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Generate in a background thread
    generation_kwargs = dict(**inputs, streamer=streamer, **gen_kwargs)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    def stream_generator():
        for new_text in streamer:
            if new_text:
                chunk = {
                    "id": "chatcmpl-" + str(int(time.time())),
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "delta": {"content": new_text},
                        "index": 0,
                        "finish_reason": None
                    }]
                }
                # print("Chunk:", chunk)
                yield f"data: {json.dumps(chunk)}\n\n"
        # Send the end-of-stream signal
        final_chunk = {
            "id": "chatcmpl-" + str(int(time.time())),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "delta": {},
                "index": 0,
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_generator(), media_type="text/event-stream")
yield is what makes the output stream: each chunk is sent to the client as a Server-Sent Events message as soon as it is generated.
三、Source Code
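On the client side the stream can be consumed with any OpenAI-compatible client. A minimal sketch using the official openai Python SDK with stream=True (client.py below does the same thing with raw requests and manual SSE parsing):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="qwen3-4b",
    messages=[{"role": "user", "content": "你好!请介绍一下你自己。"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content  # None for the final "stop" chunk
    if delta:
        print(delta, end="", flush=True)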
1、server.py
import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
from pydantic import BaseModel
from typing import List, Optional, Union, Dict, Any
import time
import json
import uvicorn

# ===== Configuration =====
MODEL_PATH = r"D:\modelbase\Qwen3-4B"  # replace with your local model path
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
print("DEVICE:", torch.cuda.is_available())
# ===== 加载模型和 tokenizer =====
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_DTYPE,
    device_map="auto",
    trust_remote_code=True
)
model.eval()
print("Model loaded.")app = FastAPI(title="Qwen3 OpenAI-compatible API")# ===== 请求体定义(兼容 OpenAI 格式)=====
class ChatCompletionRequest(BaseModel):model: strmessages: List[Dict[str, str]]temperature: Optional[float] = 0.7top_p: Optional[float] = 0.95max_tokens: Optional[int] = 512stream: Optional[bool] = Falsestop: Optional[Union[str, List[str]]] = None# ===== 工具函数 =====
def apply_chat_template(messages):# Qwen 使用自己的 chat templatetext = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)return text# ===== 路由:/v1/chat/completions =====
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    # Build the prompt
    print(f"Request messages: {request.messages}")
    prompt = apply_chat_template(request.messages)
    print("Prompt:", prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    print("Input:", inputs)

    # Generation parameters
    gen_kwargs = {
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_new_tokens": request.max_tokens,
        "do_sample": request.temperature > 0,
        "pad_token_id": tokenizer.eos_token_id,
    }

    if request.stream:
        # Streaming response
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Generate in a background thread
        generation_kwargs = dict(**inputs, streamer=streamer, **gen_kwargs)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        def stream_generator():
            for new_text in streamer:
                if new_text:
                    chunk = {
                        "id": "chatcmpl-" + str(int(time.time())),
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request.model,
                        "choices": [{
                            "delta": {"content": new_text},
                            "index": 0,
                            "finish_reason": None
                        }]
                    }
                    # print("Chunk:", chunk)
                    yield f"data: {json.dumps(chunk)}\n\n"
            # Send the end-of-stream signal
            final_chunk = {
                "id": "chatcmpl-" + str(int(time.time())),
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "delta": {},
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        # Non-streaming response
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        print("Output:", outputs)
        response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return JSONResponse({
            "id": "chatcmpl-" + str(int(time.time())),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "message": {"role": "assistant", "content": response_text},
                "index": 0,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": inputs.input_ids.shape[1],
                "completion_tokens": len(tokenizer.encode(response_text)),
                "total_tokens": inputs.input_ids.shape[1] + len(tokenizer.encode(response_text))
            }
        })

# ===== How to run =====
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info", access_log=False)
2、client.py
import requests
import json
import sys

# ===== Configuration =====
BASE_URL = "http://localhost:8000/v1/chat/completions"
MODEL_NAME = "qwen3-4b" # 与服务端 model 字段对应即可,实际由服务端忽略def chat_completion(messages, stream=False, temperature=0.7, max_tokens=512):"""调用本地 Qwen3 API,兼容 OpenAI 格式"""payload = {"model": MODEL_NAME,"messages": messages,"temperature": temperature,"max_tokens": max_tokens,"stream": stream}headers = {"Content-Type": "application/json"}if stream:with requests.post(BASE_URL, json=payload, headers=headers, stream=True) as response:if response.status_code != 200:print(f"Error: {response.status_code} - {response.text}")returnfor line in response.iter_lines():# print(f"line: {line}")if line:decoded = line.decode('utf-8')if decoded.startswith("data: "):data = decoded[len("data: "):]if data.strip() == "[DONE]":print("\n[Response completed]")breaktry:chunk = json.loads(data)delta = chunk["choices"][0]["delta"]content = delta.get("content", "")if content:print(content, end="", flush=True)except Exception as e:print(f"\n[Parse error: {e}]")else:response = requests.post(BASE_URL, json=payload, headers=headers)if response.status_code == 200:data = response.json()message = data["choices"][0]["message"]["content"]print("Response:")print(data)print("\nUsage:", data.get("usage", {}))else:print(f"Error: {response.status_code} - {response.text}")if __name__ == "__main__":# 示例对话messages = [{"role": "system", "content": "你是一个乐于助人的 AI 助手。"},{"role": "user", "content": "你好!请介绍一下你自己。"}]chat_completion(messages, stream=True)