【deepseek】本地部署后api接口的封装
文章目录
- 前言
- 一、 安装必要的库
- 二、 创建 FastAPI 应用
- 三、 运行 FastAPI 应用
- 四、 使用 API
- 五、 客户端api调用
前言
本文内容包含如何将 Deepseek本地模型的RAG 知识库文件读取、对话选择、对话回复和历史对话封装成 API 供其他应用调用的具体步骤,本地Deepseek模型部署及知识库挂载请参考上篇博客《【deepseek】本地部署+RAG知识库挂载+对话测试》
一、 安装必要的库
pip install fastapi uvicorn python-multipart langchain transformers sentence-transformers faiss-cpu unstructured pdf2image python-docx python-pptx
fastapi
: Web 框架。uvicorn
: ASGI 服务器,用于运行 FastAPI 应用。python-multipart
: 用于处理文件上传。python-pptx
: 用于读取PPT文件
二、 创建 FastAPI 应用
api.py
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from typing import List, Dict
import os
import shutil
import datetime
import json # 用于处理历史对话记录
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader, UnstructuredFileLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from pptx import Presentation # 用于读取PPT文件
app = FastAPI()
# 全局变量 (根据实际情况修改)
MODEL_PATH = "./deepseek-llm-7b-chat"
DB_FAISS_PATH = "vectorstore/db_faiss"
UPLOAD_FOLDER = "uploads" # 用于保存上传的知识库文件
HISTORY_FILE = "history.json" # 用于保存历史对话记录
os.makedirs(UPLOAD_FOLDER, exist_ok=True) # 确保上传文件夹存在
# 初始化模型、tokenizer 和向量数据库 (在应用启动时加载)
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.float16)
pipe = pipeline("text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=torch.float16,
device_map="auto",
max_new_tokens=256,
do_sample=True,
top_p=0.9,
temperature=0.7,
num_return_sequences=1)
llm = HuggingFacePipeline(pipeline=pipe)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'})
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
# 自定义 Prompt
prompt_template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答案。
上下文:{context}
问题:{question}
有用的回答:"""
prompt = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=db.as_retriever(search_kwargs={'k': 3}),
return_source_documents=True,
chain_type_kwargs={"prompt": prompt}
)
print("模型和向量数据库加载成功!")
except Exception as e:
print(f"模型或向量数据库加载失败: {e}")
raise # 抛出异常,阻止应用启动
# 辅助函数
def load_document(file_path: str):
"""加载单个文档"""
try:
if file_path.endswith(".txt"):
loader = TextLoader(file_path, encoding="utf-8")
elif file_path.endswith(".pdf"):
loader = PyPDFLoader(file_path)
elif file_path.endswith(".docx") or file_path.endswith(".doc"):
try:
loader = Docx2txtLoader(file_path)
except ImportError:
print(f"docx2txt 未安装,尝试使用 UnstructuredFileLoader 加载 {file_path}")
loader = UnstructuredFileLoader(file_path)
elif file_path.endswith(".ppt") or file_path.endswith(".pptx"):
# 使用 python-pptx 加载 PPT 文件
text = ""
try:
prs = Presentation(file_path)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
text += shape.text
except Exception as e:
print(f"加载 PPT 文件失败: {e}")
raise
# 将 PPT 内容视为一个文本文件加载
loader = TextLoader(file_path, encoding="utf-8")
loader.load()[0].page_content = text # 将提取的文本内容写入 page_content
return loader.load() # 返回加载的文档列表
else:
loader = UnstructuredFileLoader(file_path)
return loader.load()
except Exception as e:
print(f"加载文档 {file_path} 失败: {e}")
raise
def create_vector_db_from_files(file_paths: List[str]):
"""从文件列表创建向量数据库"""
documents = []
for file_path in file_paths:
try:
documents.extend(load_document(file_path))
except Exception as e:
print(f"加载文档 {file_path} 失败: {e}")
continue # 忽略加载失败的文件
if not documents:
raise ValueError("没有成功加载任何文档")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
texts = text_splitter.split_documents(documents)
if not texts:
raise ValueError("没有成功分割任何文本块")
global db # 使用全局变量
db = FAISS.from_documents(texts, embeddings)
db.save_local(DB_FAISS_PATH)
print("向量数据库创建/更新完成!")
def load_history():
"""加载历史对话记录"""
try:
with open(HISTORY_FILE, "r", encoding="utf-8") as f:
history = json.load(f)
except FileNotFoundError:
history = []
return history
def save_history(history: List[Dict[str, str]]):
"""保存历史对话记录"""
with open(HISTORY_FILE, "w", encoding="utf-8") as f:
json.dump(history, f, ensure_ascii=False, indent=4)
# API 接口
@app.post("/uploadfiles/")
async def upload_files(files: List[UploadFile] = File(...)):
"""上传知识库文件"""
file_paths = []
for file in files:
try:
file_path = os.path.join(UPLOAD_FOLDER, file.filename)
with open(file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
file_paths.append(file_path)
except Exception as e:
return JSONResponse(status_code=500, content={"message": f"文件 {file.filename} 上传失败: {e}"})
finally:
file.file.close()
try:
create_vector_db_from_files(file_paths)
return {"message": "知识库文件上传成功,向量数据库已更新!"}
except ValueError as e:
return JSONResponse(status_code=400, content={"message": str(e)})
except Exception as e:
return JSONResponse(status_code=500, content={"message": f"创建向量数据库失败: {e}"})
@app.post("/chat/")
async def chat(query: str):
"""对话接口"""
try:
response = qa_chain({"query": query})
result = response["result"]
source_documents = response["source_documents"]
# 提取来源文档信息
sources = [doc.metadata['source'] for doc in source_documents]
# 保存对话历史
history = load_history()
history.append({"user": query, "bot": result, "sources": sources, "timestamp": datetime.datetime.now().isoformat()})
save_history(history)
return {"response": result, "sources": sources}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/history/")
async def get_history():
"""获取历史对话记录"""
history = load_history()
return history
# 启动信息
@app.get("/")
async def read_root():
return {"message": "RAG API is running!"}
- 全局变量:
MODEL_PATH
,DB_FAISS_PATH
,UPLOAD_FOLDER
,HISTORY_FILE
等。 根据实际情况修改。 - 初始化模型和向量数据库: 在应用启动时加载模型和向量数据库,避免每次请求都重新加载。 使用
try...except
块捕获加载错误,并阻止应用启动,确保在模型和向量数据库加载成功后才能提供服务。 load_document()
函数: 用于加载单个文档。 支持 text, pdf, docx, ppt 等格式。create_vector_db_from_files()
函数: 用于从文件列表创建或更新向量数据库。load_history()
和save_history()
函数: 用于加载和保存历史对话记录。 历史记录保存在HISTORY_FILE
中,使用 JSON 格式存储。/uploadfiles/
接口 (POST): 用于上传知识库文件。 接收一个或多个文件,保存到UPLOAD_FOLDER
,然后调用create_vector_db_from_files()
创建或更新向量数据库。/chat/
接口 (POST): 用于对话。 接收用户输入的query
,调用 RAG 链生成回复,并将对话记录保存到历史记录中。 返回回复内容和来源文档信息。/history/
接口 (GET): 用于获取历史对话记录。 返回保存在HISTORY_FILE
中的历史记录。/
接口 (GET): 用于检查API是否正常启动
三、 运行 FastAPI 应用
在终端中运行以下命令:
uvicorn api:app --reload
api
:api.py
文件名 (不带.py
后缀)。app
: FastAPI 应用实例的名称。--reload
: 启用自动重载,当代码修改时,服务器会自动重启。
四、 使用 API
接下来可以使用任何 HTTP 客户端 (例如 curl
, Postman
, requests
库等) 来访问 API。
-
上传文件:
curl -X POST -F "files=@knowledge1.txt" -F "files=@knowledge2.pdf" http://localhost:8000/uploadfiles/
-
对话:
curl -X POST -H "Content-Type: application/json" -d '{"query": "什么是XXX?"}' http://localhost:8000/chat/
-
获取历史记录:
curl http://localhost:8000/history/
五、 客户端api调用
使用 Python 的 requests
库来调用 API
import requests
import json
API_URL = "http://your_server_ip:8000"
def upload_files(file_paths):
"""上传文件到 API"""
files = []
for file_path in file_paths:
files.append(("files", open(file_path, "rb")))
response = requests.post(f"{API_URL}/uploadfiles/", files=files)
return response.json()
def chat(query):
"""与 API 对话"""
headers = {"Content-Type": "application/json"}
data = {"query": query}
response = requests.post(f"{API_URL}/chat/", headers=headers, data=json.dumps(data))
return response.json()
def get_history():
"""获取对话历史"""
response = requests.get(f"{API_URL}/history/")
return response.json()
if __name__ == "__main__":
# 上传文件示例
# upload_result = upload_files(["knowledge1.txt", "knowledge2.pdf"])
# print("上传结果:", upload_result)
# 对话示例
while True:
query = input("你: ")
if query.lower() == "exit":
break
chat_result = chat(query)
print("AI:", chat_result["response"])
print("来源:", chat_result["sources"])
# 获取历史记录示例
# history = get_history()
# print("历史记录:", history)
API_URL
: 替换为实际的 FastAPI 服务器的 IP 地址和端口。upload_files()
: 使用requests.post()
上传文件。chat()
: 使用requests.post()
发送对话请求。get_history()
: 使用requests.get()
获取对话历史。
脚本要点:
- 异常处理: 在 API 和客户端代码中添加适当的异常处理,以提高程序的健壮性。
- 身份验证: 如果需要,可以添加身份验证机制,例如 API 密钥或 JWT。
- 日志记录: 添加日志记录功能,方便调试和监控。
- 并发处理: FastAPI 支持异步处理,可以提高并发性能。
- 模型和向量数据库的加载: 确保模型和向量数据库只加载一次,避免重复加载。
- 文件上传安全: 对上传的文件进行安全检查,防止恶意文件上传。
- Prompt 优化: 根据实际情况优化 Prompt,提高回复质量。
- 向量数据库更新策略: 考虑如何更新向量数据库,例如定期更新或增量更新。
- 错误处理: 在API中增加详细的错误处理,方便客户端调试。
其他应用建议:
- 对话选择接口: 可以添加一个接口,让用户选择使用哪个知识库进行对话。
- 多轮对话管理: 可以使用 Langchain 的
ConversationChain
或ConversationBufferMemory
来管理多轮对话。 - 自定义 Prompt: 允许用户自定义 Prompt,以满足不同的需求。
- 流式回复: 可以使用 FastAPI 的
StreamingResponse
来实现流式回复,提高用户体验。