D2-基于本地Ollama模型的多轮问答系统
本程序是一个基于 Gradio 和 Ollama API 构建的支持多轮对话的写作助手。相较于上一版本,本版本新增了对话历史记录、Token 计数、参数调节和清空对话功能,显著提升了用户体验和交互灵活性。
程序通过抽象基类 LLMAgent
实现模块化设计,当前使用 OllamaAgent
作为具体实现,调用本地部署的 Ollama 大语言模型(如 qwen3:8b
)生成写作建议,并提供一个交互式的 Web 界面供用户操作。
设计支持未来扩展到其他 LLM 平台(如 OpenAI、HuggingFace),只需实现新的 LLMAgent
子类即可。
环境配置
依赖安装
需要以下 Python 库:
gradio
:用于创建交互式 Web 界面。requests
:向 Ollama API 发送 HTTP 请求。json
:解析 API 响应数据(Python 内置)。logging
:记录运行日志(Python 内置)。abc
:定义抽象基类(Python 内置)。tiktoken
:精确计算 Token 数量以管理输入和历史长度。
安装命令:
pip install gradio requests tiktoken
建议使用 Python 3.8 或更高版本。
Ollama 服务配置
-
安装 Ollama
从 https://ollama.ai/ 下载并安装。 -
启动 Ollama 服务
ollama serve
- 默认监听地址:
http://localhost:11434
。
- 默认监听地址:
-
下载模型
ollama pull qwen3:8b
-
验证模型
ollama list
运行程序
- 将代码保存为
writing_assistant.py
。 - 确保 Ollama 服务正在运行。
- 执行程序:
python writing_assistant.py
- 打开浏览器访问界面(通常为
http://127.0.0.1:7860
)。 - 输入写作提示,调整参数后点击“获取写作建议”,查看结果和对话历史。
代码说明
1. 依赖导入
import gradio as gr
import requests
import json
import logging
from abc import ABC, abstractmethod
import tiktoken
tiktoken
:精确计算 Token 数量,优化输入控制。
2. 日志配置
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
- 配置日志级别为
INFO
,记录 API 调用和错误信息,便于调试。
3. 抽象基类:LLMAgent
class LLMAgent(ABC):@abstractmethoddef generate_response(self, prompt):pass
- 定义通用 LLM 代理接口,要求实现
generate_response
方法。 - 支持未来扩展到其他 LLM 平台(如 OpenAI、Anthropic)。
4. 具体实现:OllamaAgent
class OllamaAgent(LLMAgent):def __init__(self, config): ...def set_max_history_length(self, max_rounds): ...def set_parameters(self, max_tokens, temperature): ...def generate_response(self, prompt): ...def clear_history(self): ...
- 对话历史:维护
history
列表,支持多轮对话。 - 参数调节:动态调整
max_tokens
和temperature
。 - Token 管理:自动截断历史记录,防止超出模型上下文限制。
- 错误处理:捕获网络请求失败和 JSON 解析错误,返回用户友好的提示。
5. Token 计数函数
def calculate_tokens(text): ...
def calculate_history_tokens(history): ...
- 使用
tiktoken
精确估算 Token 数量,提升输入长度控制能力。
6. 历史格式化
def format_history_for_chatbot(history): ...
- 将内部
history
结构转换为 Gradio 的Chatbot
格式[user_msg, assistant_msg]
。
7. 核心逻辑:generate_assistance
def generate_assistance(prompt, agent, max_rounds, max_tokens, temperature): ...
- 设置最大对话轮数和生成参数。
- 调用
agent.generate_response
获取响应。 - 返回格式化的对话历史、最新回复和 Token 计数。
8. 辅助函数
def update_token_count(prompt): ...
def clear_conversation(agent): ...
- 实时更新输入 Token 数量。
- 清空对话历史并重置状态。
9. 主函数:main
def main():config = { ... }agent = OllamaAgent(config)with gr.Blocks(...) as demo:...demo.launch()
- 增强的 UI:包含输入框、Token 显示、参数调节滑块和清空按钮。
- Gradio 事件绑定:
prompt_input.change()
:动态更新 Token 计数。submit_button.click()
:触发写作建议生成。clear_button.click()
:重置对话历史。
运行流程图
graph TDA[用户输入提示] --> B[点击 submit_button]B --> C[调用 generate_assistance(prompt, agent, 参数)]C --> D[调用 agent.set_* 设置参数]D --> E[调用 agent.generate_response(prompt)]E --> F[向 Ollama API 发送 POST 请求]F --> G[接收并解析 JSON 响应]G --> H[更新聊天历史和输出结果]
注意事项
- Ollama 服务:确保服务运行并监听在
http://localhost:11434/v1
。 - 模型可用性:确认
qwen3:8b
已下载。 - Token 上限:注意模型的最大上下文长度(如 4096 Tokens),避免历史过长导致超限。
- 参数影响:
temperature
:控制生成随机性(较低值更确定,较高值更具创造性)。max_tokens
:限制输出长度。
- 调试信息:查看终端日志,确认 API 响应是否正常或是否有错误。
未来改进建议
- 多模型支持:添加
OpenAIAgent
等子类,通过下拉菜单切换模型。 - 配置文件化:将硬编码配置移至 JSON/YAML 文件。
- 异步请求:使用
aiohttp
替换requests
,提升并发性能。 - 对话持久化:将历史对话保存到本地文件或数据库。
- 用户认证:区分不同用户的对话记录。
- 移动端适配:优化界面布局以适配手机端。
示例使用
- 启动程序后访问
http://127.0.0.1:7860
。 - 输入提示:“帮我写一段关于环保的文章。”
- 调整参数(如
max_tokens=1000
,temperature=0.2
)。 - 点击“获取写作建议”,查看类似以下输出:
在这里插入图片描述
代码
import gradio as gr
import requests
import json
import logging
from abc import ABC, abstractmethod
import tiktoken# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)# 抽象基类:定义通用的 LLM Agent 接口
class LLMAgent(ABC):@abstractmethoddef generate_response(self, prompt):pass# Ollama 特定的实现
class OllamaAgent(LLMAgent):def __init__(self, config):self.model = config["model"]self.base_url = config["base_url"]self.api_key = config["api_key"]self.max_tokens = config["max_tokens"]self.temperature = config["temperature"]self.history = []self.max_history_length = 10def set_max_history_length(self, max_rounds):self.max_history_length = int(max_rounds * 2)if len(self.history) > self.max_history_length:self.history = self.history[-self.max_history_length:]def set_parameters(self, max_tokens, temperature):self.max_tokens = int(max_tokens)self.temperature = float(temperature)def generate_response(self, prompt):self.history.append({"role": "user", "content": prompt})if len(self.history) > self.max_history_length:self.history = self.history[-self.max_history_length:]url = f"{self.base_url}/chat/completions"headers = {"Authorization": f"Bearer {self.api_key}","Content-Type": "application/json"}payload = {"model": self.model,"messages": self.history,"max_tokens": self.max_tokens,"temperature": self.temperature}try:response = requests.post(url, headers=headers, json=payload)response.raise_for_status()result = response.json()content = result['choices'][0]['message']['content']logger.info(f"API 响应: {content}")self.history.append({"role": "assistant", "content": content})return contentexcept requests.exceptions.RequestException as e:logger.error(f"API 请求失败: {str(e)}")return f"错误:无法连接到 Ollama API: {str(e)}"except KeyError as e:logger.error(f"解析响应失败: {str(e)}")return f"错误:解析响应失败: {str(e)}"def clear_history(self):self.history = []def calculate_tokens(text):if not text:return 0cleaned_text = text.strip().replace('\n', '')try:encoding = tiktoken.get_encoding("cl100k_base")tokens = encoding.encode(cleaned_text)return len(tokens)except Exception as e:logger.error(f"Token 计算失败: {str(e)}")return len(cleaned_text)def calculate_history_tokens(history):total_tokens = 0try:encoding = tiktoken.get_encoding("cl100k_base")for message in history:content = message["content"].strip()tokens = encoding.encode(content)total_tokens += len(tokens)return total_tokensexcept Exception as e:logger.error(f"历史 Token 计算失败: {str(e)}")return sum(len(msg["content"].strip()) for msg in history)def format_history_for_chatbot(history):"""将 agent.history 转换为 gr.Chatbot 所需格式:List[List[str, str]]"""messages = []for i in range(0, len(history) - 1, 2):if history[i]["role"] == "user" and history[i+1]["role"] == "assistant":messages.append([history[i]["content"], history[i+1]["content"]])return messagesdef generate_assistance(prompt, agent, max_rounds, max_tokens, temperature):agent.set_max_history_length(max_rounds)agent.set_parameters(max_tokens, temperature)response = agent.generate_response(prompt)history_tokens = calculate_history_tokens(agent.history)chatbot_format_history = format_history_for_chatbot(agent.history)return chatbot_format_history, response, f"历史总 token 数(估算):{history_tokens}"def update_token_count(prompt):return f"当前输入 token 数(精确):{calculate_tokens(prompt)}"def clear_conversation(agent):agent.clear_history()return [], "对话已清空", "历史总 token 数(估算):0"def main():config = {"api_type": "ollama","model": "qwen3:8b","base_url": "http://localhost:11434/v1","api_key": "ollama","max_tokens": 1000,"temperature": 0.2}agent = OllamaAgent(config)with gr.Blocks(title="写作助手") as demo:gr.Markdown("# 写作助手(支持多轮对话)")gr.Markdown("输入您的写作提示,获取建议和指导!支持连续对话,调整对话轮数、max_tokens 和 temperature,或点击“清空对话”重置。")with gr.Row():with gr.Column():prompt_input = gr.Textbox(label="请输入您的提示",placeholder="例如:帮我写一段关于环保的文章",lines=3)token_count = gr.Textbox(label="输入 token 数",value="当前输入 token 数(精确):0",interactive=False)history_token_count = gr.Textbox(label="历史 token 数",value="历史总 token 数(估算):0",interactive=False)max_rounds = gr.Slider(minimum=1,maximum=10,value=5,step=1,label="最大对话轮数",info="设置保留的对话轮数(每轮包含用户和模型消息)")max_tokens = gr.Slider(minimum=100,maximum=2000,value=1000,step=100,label="最大生成 token 数",info="控制单次生成的最大 token 数")temperature = gr.Slider(minimum=0.0,maximum=1.0,value=0.2,step=0.1,label="Temperature",info="控制生成随机性,0.0 为确定性,1.0 为较随机")submit_button = gr.Button("获取写作建议")clear_button = gr.Button("清空对话")with gr.Column():chatbot = gr.Chatbot(label="对话历史")output = gr.Textbox(label="最新生成结果", lines=5)prompt_input.change(fn=update_token_count,inputs=prompt_input,outputs=token_count)submit_button.click(fn=generate_assistance,inputs=[prompt_input, gr.State(value=agent), max_rounds, max_tokens, temperature],outputs=[chatbot, output, history_token_count])clear_button.click(fn=clear_conversation,inputs=gr.State(value=agent),outputs=[chatbot, output, history_token_count])demo.launch()if __name__ == "__main__":main()