从零构建RAG知识库管理系统(二)
第二篇:《后端核心实现:FastAPI服务搭建》
引言
在上一篇文章中,我们对RAG知识库管理系统的整体架构进行了概述。本文将深入探讨后端服务的核心实现,重点介绍FastAPI框架的配置、路由设计、数据库集成以及用户认证系统的实现。
FastAPI框架基础配置
FastAPI是一个现代、快速(高性能)的Python Web框架,基于标准Python类型提示构建。它具有快速编码、快速执行和快速开发的特点。
主应用配置 (main.py)
代码功能简短总结
- 环境与路径配置:导入
os
、sys
等模块,加载.env
文件环境变量,并将项目根目录添加到Python路径,确保模块正常导入。 - 依赖与路由导入:引入
FastAPI
及CORS中间件,导入认证(auth_router
)、知识库管理(kb_router
,含文件上传)、聊天(chat_router
)三类路由,同时导入数据库表创建函数。 - FastAPI应用初始化:创建标题为“RAG Project API”、版本1.0.0的应用,调用
create_tables()
创建数据库表。 - CORS配置:添加CORS中间件,允许所有源(生产环境需指定具体域名)、凭证、请求方法及请求头,解决跨域问题。
- 路由与根路径注册:包含上述三类路由,定义根路径
/
,访问时返回消息“我是RAG项目API”。 - 运行配置:若直接执行该文件,通过
uvicorn
在127.0.0.1:8080
启动服务。
import os
import sys
from dotenv import load_dotenv# 加载.env文件
load_dotenv()# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware# 导入路由
from rag_back.fastapi.auth import router as auth_router
from rag_back.fastapi.kb_routes import router as kb_router # 知识库管理路由包含了文件上传功能
from rag_back.fastapi.chat import router as chat_router # 聊天路由
from rag_back.persistent.database import create_tables# 创建FastAPI应用
app = FastAPI(title="RAG Project API", version="1.0.0")# 创建数据库表
create_tables()# 配置CORS
app.add_middleware(CORSMiddleware,allow_origins=["*"], # 在生产环境中应该指定具体的域名allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)# 包含路由
app.include_router(auth_router)
app.include_router(kb_router) # 注册知识库管理路由
app.include_router(chat_router) # 注册聊天路由# 根路径
@app.get("/")
async def root():return {"message": "我是RAG项目API"}if __name__ == "__main__":import uvicornuvicorn.run(app, host="127.0.0.1", port=8080)
在主应用配置中,我们首先加载了环境变量,然后创建了FastAPI应用实例。通过CORS中间件配置,我们允许前端应用与后端API进行跨域通信。最后,我们将各个功能模块的路由注册到主应用中。
路由设计
系统采用模块化路由设计,将不同功能的API接口分别组织在不同的路由文件中。
认证路由 (auth.py)
代码简短总结
该代码基于FastAPI实现用户认证功能,核心要点如下:
依赖与配置
- 导入FastAPI、JWT、密码加密(passlib)、SQLAlchemy等依赖
- 配置JWT(密钥、算法HS256、令牌有效期30分钟)、数据库会话(SessionLocal)
- 定义路由前缀
/auth
,标签["auth"]
数据模型
Token
:返回令牌结构(access_token、token_type)TokenData
:令牌解析后的用户数据(含可选username)UserCreate
:用户注册请求体(username、email、full_name、password)
核心工具函数
- 数据库:
get_db()
获取数据库会话(自动关闭)- 密码处理:
verify_password()
验证密码、get_password_hash()
生成密码哈希 - 用户操作:
get_user()
查用户、authenticate_user()
验证用户名密码 - JWT操作:
create_access_token()
生成访问令牌、get_current_user()
解析令牌获当前用户
- 密码处理:
接口功能
- 注册接口(
POST /auth/register
):校验用户名唯一性,加密密码存入数据库,返回令牌- 登录接口(
POST /auth/token
):通过OAuth2表单验证用户,成功后返回令牌 - 异常处理:用户名已注册(400)、 credentials无效/用户名密码错误(401)
- 登录接口(
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel# 导入数据库相关
from rag_back.persistent.database import SessionLocal, User
from sqlalchemy.orm import Session# 导入配置
from rag_back.rag.config import RagConfig# 创建路由实例
router = APIRouter(prefix="/auth", tags=["auth"])# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")# OAuth2密码Bearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")# JWT配置
SECRET_KEY = "your-secret-key" # 在实际应用中应该从环境变量获取
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30class Token(BaseModel):access_token: strtoken_type: strclass TokenData(BaseModel):username: Optional[str] = Noneclass UserCreate(BaseModel):username: stremail: strfull_name: strpassword: str# 获取数据库会话
def get_db():db = SessionLocal()try:yield dbfinally:db.close()# 验证密码
def verify_password(plain_password, hashed_password):return pwd_context.verify(plain_password, hashed_password)# 获取密码哈希值
def get_password_hash(password):return pwd_context.hash(password)# 根据用户名获取用户
def get_user(db: Session, username: str):return db.query(User).filter(User.username == username).first()# 验证用户
def authenticate_user(db: Session, username: str, password: str):user = get_user(db, username)if not user:return Falseif not verify_password(password, user.hashed_password):return Falsereturn user# 创建访问令牌
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):to_encode = data.copy()if expires_delta:expire = datetime.utcnow() + expires_deltaelse:expire = datetime.utcnow() + timedelta(minutes=15)to_encode.update({"exp": expire})encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)return encoded_jwt# 获取当前用户
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail="Could not validate credentials",headers={"WWW-Authenticate": "Bearer"},)try:payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])username: str = payload.get("sub")if username is None:raise credentials_exceptiontoken_data = TokenData(username=username)except JWTError:raise credentials_exceptionuser = get_user(db, username=token_data.username)if user is None:raise credentials_exceptionreturn user# 用户注册接口
@router.post("/register", response_model=Token)
def register_user(user: UserCreate, db: Session = Depends(get_db)):db_user = get_user(db, username=user.username)if db_user:raise HTTPException(status_code=400, detail="Username already registered")hashed_password = get_password_hash(user.password)db_user = User(username=user.username,email=user.email,full_name=user.full_name,hashed_password=hashed_password)db.add(db_user)db.commit()db.refresh(db_user)# 创建访问令牌access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)access_token = create_access_token(data={"sub": user.username}, expires_delta=access_token_expires)return {"access_token": access_token, "token_type": "bearer"}# 获取访问令牌接口
@router.post("/token", response_model=Token)
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):user = authenticate_user(db, form_data.username, form_data.password)if not user:raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail="Incorrect username or password",headers={"WWW-Authenticate": "Bearer"},)access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)access_token = create_access_token(data={"sub": user.username}, expires_delta=access_token_expires)return {"access_token": access_token, "token_type": "bearer"}
知识库管理路由 (kb_routes.py)
代码功能简短总结
该代码是基于FastAPI构建的知识库(Knowledge Base)后端接口模块,核心功能围绕知识库的管理与文件处理,要点如下:
- 核心依赖与初始化
- 依赖模块:FastAPI(接口)、Milvus(向量数据库)、MinIO(文件存储)、SQLAlchemy(关系型数据库)、LlamaIndex(嵌入模型)等
- 初始化操作:创建MinIO客户端并确保存储桶存在、连接Milvus向量数据库、创建本地文档目录、初始化FastAPI路由(前缀
/api/kb
,标签knowledge_base
)
- 初始化操作:创建MinIO客户端并确保存储桶存在、连接Milvus向量数据库、创建本地文档目录、初始化FastAPI路由(前缀
- 数据库核心操作函数
get_db()
:获取SQLAlchemy数据库会话(自动关闭)create_knowledge_base()
:在关系库中创建知识库记录(含名称、描述)get_knowledge_base_by_name()
/get_all_knowledge_bases()
:按名称查询/查询所有知识库
- 核心接口功能
- GET /api/kb/collections:获取所有知识库列表
- 从关系库读取知识库信息,将UTC创建时间转为北京时区(+8小时)
- 含异常处理,返回成功/失败状态(封装为
R
对象) - 返回字段:知识库名称、描述、创建时间、文件数量
- 含异常处理,返回成功/失败状态(封装为
- POST /api/kb/collections:创建知识库集合并可选上传文件
- 先校验Milvus/关系库中是否已存在该知识库,避免重复
- 异常时回滚(如创建空集合失败则删除关系库记录)
- 有文件时:处理文件并创建索引,更新关系库中文件数量
- 无文件时:创建空Milvus集合,同步关系库记录
- 关键配置与工具
- 嵌入模型:使用本地
embed_model_local_bge_small
,避免依赖外部API- 时间处理:统一将UTC时间转为北京时区展示
- 存储保障:MinIO存储桶/本地文档目录自动创建,确保存储可用
import os
import uuid
from typing import List
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pymilvus import utility, connections, Collection, CollectionSchema, FieldSchema, DataType, MilvusClient# 导入必要的模块
from rag_back.rag.traditional_rag import TraditionalRAG
from rag_back.utils.r import R
import time
# 导入MinIO客户端
from minio import Minio
# 导入配置类
from rag_back.rag.config import RagConfig
# 导入Settings和本地嵌入模型
from llama_index.core import Settings
from rag_back.rag.embeddings import embed_model_local_bge_small
from datetime import datetime, timedelta# 导入数据库相关
from rag_back.persistent.database import SessionLocal, KnowledgeBase, User
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from sqlalchemy import update# 导入认证相关
from rag_back.fastapi.auth import get_current_user# 创建路由实例
router = APIRouter(prefix="/api/kb", tags=["knowledge_base"])# 初始化MinIO客户端
minio_client = Minio(endpoint=RagConfig.MINIO_ENDPOINT,access_key=RagConfig.MINIO_ACCESS_KEY,secret_key=RagConfig.MINIO_SECRET_KEY,secure=False # 根据你的配置设置为True或False
)# 确保文件夹存在
documents_dir = os.path.join(os.getcwd(), "documents")
os.makedirs(documents_dir, exist_ok=True)# 确保存储桶存在
try:if not minio_client.bucket_exists(RagConfig.MINIO_BUCKET_NAME):minio_client.make_bucket(RagConfig.MINIO_BUCKET_NAME)
except Exception as e:print(f"创建MinIO存储桶失败: {e}")# 连接到Milvus
try:connections.connect(uri=RagConfig.Milvus_uri)
except Exception as e:print(f"连接Milvus失败: {e}")def get_db():"""获取数据库会话"""db = SessionLocal()try:yield dbfinally:db.close()def create_knowledge_base(db: Session, name: str, description: str = ""):"""创建知识库记录"""kb = KnowledgeBase(name=name, description=description)db.add(kb)db.commit()db.refresh(kb)return kbdef get_knowledge_base_by_name(db: Session, name: str):"""根据名称获取知识库"""return db.query(KnowledgeBase).filter(KnowledgeBase.name == name).first()def get_all_knowledge_bases(db: Session):"""获取所有知识库"""return db.query(KnowledgeBase).all()@router.get("/collections")
async def get_collections():"""获取所有知识库集合列表及创建时间"""try:db = next(get_db())# 获取数据库中的所有知识库knowledge_bases = get_all_knowledge_bases(db)# 构造返回数据,包含名称、创建时间和文件数量result = []for kb in knowledge_bases:# 将UTC时间转换为本地时间created_at_local = getattr(kb, 'created_at', None)if created_at_local is not None:# 转换为东八区时间(北京时间)# kb.created_at 在查询后会自动转换为 datetime 对象created_at_beijing = created_at_local + timedelta(hours=8)created_at_str = created_at_beijing.strftime("%Y-%m-%d %H:%M:%S")else:created_at_str = "未知"result.append({"name": kb.name,"description": kb.description,"created_at": created_at_str,"file_count": kb.file_count})db.close()return R.ok("获取知识库列表成功", data=result)except Exception as e:return R.error(f"获取知识库列表失败: {str(e)}")@router.post("/collections")
async def create_collection_with_files(collection_name: str = Form(...),description: str = Form(""),files: List[UploadFile] = File(default=[])
):"""创建新的知识库集合并可选择性地上传文件"""db = next(get_db())try:# 检查集合是否已存在(在Milvus中)if utility.has_collection(collection_name):db.close()return R.error("知识库已存在", status_code=400)# 在数据库中创建知识库记录try:kb = create_knowledge_base(db, collection_name, description)except IntegrityError:db.close()return R.error("知识库已存在", status_code=400)# 如果有文件需要上传,则处理文件上传和索引创建if files and len(files) > 0:try:result_data = await process_files_and_create_index(files, collection_name)# 更新文件数量processed_files = result_data["processed_files"]increment_knowledge_base_file_count(db, collection_name, processed_files)db.close()return R.ok(f"成功处理 {processed_files} 个文件", data=result_data)except Exception as e:db.close()return R.error(f"处理文件时出错: {str(e)}")else:# 即使没有文件,也要在Milvus中创建空集合try:# 设置本地嵌入模型,避免使用OpenAISettings.embed_model = embed_model_local_bge_small()# 创建一个空的RAG实例rag = TraditionalRAG(files=[])# 创建索引(这会在Milvus中创建空集合)await rag.create_index_remote(collection_name=collection_name)db.close()return R.ok(f"知识库 '{collection_name}' 创建成功")except Exception as e:# 如果创建空集合失败,删除数据库记录delete_knowledge_base(db, collection_name)db.close()return R.error(f"创建知识库失败: {str(e)}")except Exception as e:db.close()return R.error(f"创建知识库失败: {str(e)}")
聊天路由 (chat.py)
代码简短总结
该代码是一个基于FastAPI的RAG(检索增强生成)聊天系统后端核心模块,要点如下:
1. 核心依赖与配置
- 基础库:导入os、sys、asyncio、sqlalchemy等,用于路径处理、异步操作、数据库交互
- 第三方库:FastAPI(接口开发)、MinIO(对象存储)、llama-index(RAG核心)、pydantic(数据校验)
- 配置项:日志初始化、MinIO客户端初始化(指定端点/密钥/存储桶)、项目路径添加、数据库会话创建
2. 数据模型(Pydantic)
定义接口请求/响应的数据结构,确保数据格式合规:
模型名 | 用途 | 关键字段 |
SessionCreate/Update | 会话创建/更新 | title(会话标题) |
SessionResponse | 会话响应 | id/user_id/title/created_at/updated_at |
MessageCreate | 消息创建 | session_id/content/model(默认qianwen)/collection(知识库) |
MessageResponse | 消息响应 | id/session_id/role/content/created_at/model |
SuccessResponse/ListResponse | 统一响应 | code(默认200)/message/data |
3. 核心工具函数
get_llm_by_name
:根据模型名(deepseek/moonshot/qianwen)返回对应LLM实例get_db
:数据库依赖函数,生成数据库会话并自动关闭- 时间序列化:
serialize_datetime
,将datetime转为ISO格式字符串
4. API接口(FastAPI路由)
前缀为/chat
,标签为chat
,需用户认证(get_current_user
依赖):
- /user/me:GET请求,获取当前登录用户信息(id/username/email/full_name)
- /sessions(POST):创建会话,接收标题,返回会话详情
- /sessions(GET):获取当前用户所有会话,按更新时间倒序排列,返回会话列表
5. 全局状态与存储
upload_tasks
:字典,存储后台任务状态(session_id为键,包含status/progress/message)- 存储处理:创建本地
documents
文件夹、检查MinIO存储桶是否存在(不存在则创建) - 数据库关联:关联User(用户)、ChatSession(会话)、Message(消息)、SessionFile(会话文件)表
import os
import sys
import uuid
import time # 添加time模块导入
import asyncio
import json
import logging
import threading # 添加线程支持
from typing import List, Optional, Any, AsyncGenerator
from functools import partial# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel, field_serializer
from datetime import datetime# 添加日志配置
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from rag_back.persistent.database import SessionLocal, User, Session as ChatSession, Message, SessionFile, get_beijing_time
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from rag_back.rag.llms import qianwen_llm, moonshot_llm, deepseek_llm# 导入正确的用户认证函数
from rag_back.fastapi.auth import get_current_user# 导入MinIO客户端
from minio import Minio# 导入配置类
from rag_back.rag.config import RagConfig# 导入RAG相关模块
from rag_back.rag.traditional_rag import TraditionalRAG
from rag_back.rag.base_rag import RAG
from rag_back.rag.embeddings import embed_model_local_bge_small# 导入Milvus工具
from rag_back.utils.milvus import drop_collection# 导入响应工具
from rag_back.utils.r import R# 创建路由实例 - 移除 /api 前缀,因为前端直接调用 /chat
router = APIRouter(prefix="/chat", tags=["chat"])# 全局任务状态字典,用于存储后台任务的处理状态
upload_tasks = {} # {session_id: {"status": "queued/processing/completed/failed", "progress": 0-100, "message": ""}}# 初始化MinIO客户端
minio_client = Minio(endpoint=RagConfig.MINIO_ENDPOINT,access_key=RagConfig.MINIO_ACCESS_KEY,secret_key=RagConfig.MINIO_SECRET_KEY,secure=False # 根据你的配置设置为True或False
)# 确保文件夹存在
documents_dir = os.path.join(os.getcwd(), "documents")
os.makedirs(documents_dir, exist_ok=True)# 确保存储桶存在
try:if not minio_client.bucket_exists(RagConfig.MINIO_BUCKET_NAME):minio_client.make_bucket(RagConfig.MINIO_BUCKET_NAME)
except Exception as e:print(f"创建MinIO存储桶失败: {e}")# 数据库依赖
def get_db():db = SessionLocal()try:yield dbfinally:db.close()# Pydantic模型
class SessionCreate(BaseModel):title: strclass SessionUpdate(BaseModel):title: strclass SessionResponse(BaseModel):id: intuser_id: inttitle: strcreated_at: datetimeupdated_at: datetime@field_serializer('created_at', 'updated_at')def serialize_datetime(self, dt: datetime) -> str:# 确保时间以ISO格式正确序列化return dt.isoformat()class Config:from_attributes = Trueclass MessageCreate(BaseModel):session_id: intcontent: str# 添加模型选择字段model: Optional[str] = "qianwen" # 默认使用通义千问# 添加知识库选择字段collection: Optional[str] = None # 选择的公共知识库class MessageResponse(BaseModel):id: intsession_id: introle: strcontent: strcreated_at: datetime# 添加模型字段model: Optional[str] = None@field_serializer('created_at')def serialize_datetime(self, dt: datetime) -> str:# 确保时间以ISO格式正确序列化return dt.isoformat()class Config:from_attributes = Trueclass ChatResponse(BaseModel):message: MessageResponsesession: SessionResponseclass SuccessResponse(BaseModel):code: int = 200message: strdata: Optional[Any] = Noneclass ListResponse(BaseModel):code: int = 200message: strdata: List[Any]# 根据模型名称获取对应的LLM实例
def get_llm_by_name(model_name: str):if model_name == "deepseek":return deepseek_llm()elif model_name == "moonshot":return moonshot_llm()else: # 默认使用通义千问return qianwen_llm()# 获取当前用户信息
@router.get("/user/me", response_model=SuccessResponse)
async def get_user_info(user: User = Depends(get_current_user)):"""获取当前登录用户的信息"""return SuccessResponse(message="获取用户信息成功",data={"id": user.id,"username": user.username,"email": user.email,"full_name": user.full_name})# 创建会话
@router.post("/sessions", response_model=SuccessResponse)
def create_session(session_data: SessionCreate,user: User = Depends(get_current_user),db: Session = Depends(get_db)
):try:db_session = ChatSession(user_id=user.id,title=session_data.title)db.add(db_session)db.commit()db.refresh(db_session)# 手动创建响应对象,避免关系字段问题session_response = SessionResponse(id=getattr(db_session, 'id'),user_id=getattr(db_session, 'user_id'),title=getattr(db_session, 'title'),created_at=getattr(db_session, 'created_at'),updated_at=getattr(db_session, 'updated_at'))return SuccessResponse(message="会话创建成功", data=session_response)except Exception as e:db.rollback()raise HTTPException(status_code=500, detail=f"创建会话失败: {str(e)}")# 获取用户的所有会话
@router.get("/sessions", response_model=ListResponse)
def get_sessions(user: User = Depends(get_current_user),db: Session = Depends(get_db)
):try:sessions = db.query(ChatSession).filter(ChatSession.user_id == user.id).order_by(ChatSession.updated_at.desc()).all()session_responses = [SessionResponse(id=getattr(session, 'id'),user_id=getattr(session, 'user_id'),title=getattr(session, 'title'),created_at=getattr(session, 'created_at'),updated_at=getattr(session, 'updated_at')) for session in sessions]return ListResponse(message="获取会话列表成功", data=session_responses)except Exception as e:raise HTTPException(status_code=500, detail=f"获取会话列表失败: {str(e)}")
数据库模型设计与SQLAlchemy集成
系统使用SQLAlchemy作为ORM工具,与MySQL数据库进行交互。
数据库模型定义 (database.py)
代码功能简短总结
该代码是基于Python和SQLAlchemy的MySQL数据库模型定义脚本,核心功能为RAG(检索增强生成)项目构建数据库结构,要点如下:
- 环境与路径配置
- 加载.env文件读取环境变量,默认值兜底数据库配置
- 将项目根目录添加到Python路径,确保模块可导入
- 数据库连接设置
- 从环境变量获取MySQL的主机、端口、用户名、密码、数据库名
- 构建数据库URL,创建SQLAlchemy引擎(关闭日志打印)和会话工厂(禁用自动提交/刷新)
- 工具函数定义
get_beijing_time()
:获取当前UTC时间并转换为北京时间,返回无时区信息的时间对象
- 5个核心数据模型(表结构)
模型类名 | 对应表名 | 核心字段/功能 |
|
| 存储用户信息(用户名、邮箱、加密密码等),与Session为一对多关系 |
|
| 存储知识库信息(名称、描述、文件数等),含专属提示词模板字段(占位符{context}/{query}) |
|
| 存储会话信息(标题、所属用户等),与User、Message、SessionFile分别建立关联,支持级联删除 |
|
| 存储会话消息(角色user/assistant、内容、使用模型等),关联所属Session |
|
| 存储会话上传文件信息(原始名、保存路径、MinIO对象名等),关联所属Session |
- 表创建功能
create_tables()
:通过SQLAlchemy基类metadata创建所有定义的数据库表
import os
import sys
from dotenv import load_dotenv# 加载.env文件
load_dotenv()# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from sqlalchemy import create_engine, Column, Integer, String, DateTime, Text, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from datetime import datetime, timezone, timedelta# 从环境变量获取数据库配置
MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
MYSQL_PORT = os.getenv("MYSQL_PORT", "3306")
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "123456")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "rag_project")# 创建数据库URL
DATABASE_URL = f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DATABASE}"# 创建引擎
engine = create_engine(DATABASE_URL, echo=False)# 创建会话
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)# 创建基类
Base = declarative_base()# 获取北京时间的函数
def get_beijing_time():"""获取当前北京时间"""utc_time = datetime.utcnow()beijing_time = utc_time.replace(tzinfo=timezone.utc).astimezone(timezone(timedelta(hours=8)))return beijing_time.replace(tzinfo=None) # 返回不带时区信息的时间,保持与原来一致# 用户模型
class User(Base):__tablename__ = "users"id = Column(Integer, primary_key=True, index=True)username = Column(String(50), unique=True, index=True, nullable=False)email = Column(String(100), unique=True, index=True, nullable=False)full_name = Column(String(100), nullable=False)hashed_password = Column(String(100), nullable=False)created_at = Column(DateTime, default=get_beijing_time)updated_at = Column(DateTime, default=get_beijing_time, onupdate=get_beijing_time)# 关系sessions = relationship("Session", back_populates="user")# 知识库模型
class KnowledgeBase(Base):__tablename__ = "knowledge_bases"id = Column(Integer, primary_key=True, index=True)name = Column(String(255), unique=True, index=True, nullable=False)description = Column(Text, nullable=True)file_count = Column(Integer, default=0)# 新增:知识库专属提示词模板system_prompt = Column(Text, nullable=True, comment="知识库专属提示词模板,使用{context}和{query}作为占位符")created_at = Column(DateTime, default=get_beijing_time)updated_at = Column(DateTime, default=get_beijing_time, onupdate=get_beijing_time)# 会话模型
class Session(Base):__tablename__ = "sessions"id = Column(Integer, primary_key=True, index=True)user_id = Column(Integer, ForeignKey("users.id"), nullable=False)title = Column(String(255), nullable=False)created_at = Column(DateTime, default=get_beijing_time)updated_at = Column(DateTime, default=get_beijing_time, onupdate=get_beijing_time)# 关系user = relationship("User", back_populates="sessions")messages = relationship("Message", back_populates="session", cascade="all, delete-orphan")files = relationship("SessionFile", back_populates="session", cascade="all, delete-orphan")# 消息模型
class Message(Base):__tablename__ = "messages"id = Column(Integer, primary_key=True, index=True)session_id = Column(Integer, ForeignKey("sessions.id"), nullable=False)role = Column(String(50), nullable=False) # 'user' 或 'assistant'content = Column(Text, nullable=False)# 添加模型字段model = Column(String(50), nullable=True) # 使用的模型名称created_at = Column(DateTime, default=get_beijing_time)# 关系session = relationship("Session", back_populates="messages")# 会话文件模型
class SessionFile(Base):__tablename__ = "session_files"id = Column(Integer, primary_key=True, index=True)session_id = Column(Integer, ForeignKey("sessions.id"), nullable=False)original_name = Column(String(255), nullable=False)saved_path = Column(String(500), nullable=False)minio_object_name = Column(String(500), nullable=False)uploaded_at = Column(DateTime, default=get_beijing_time)# 关系session = relationship("Session", back_populates="files")# 创建表
def create_tables():Base.metadata.create_all(bind=engine)
用户认证系统实现(JWT Token)
用户认证系统是系统安全的重要组成部分,我们使用JWT(JSON Web Token)来实现用户认证。
JWT配置和工具函数
代码简短总结
该代码主要实现密码加密验证与JWT访问令牌生成两大核心功能,具体要点如下:
- 依赖库引入:导入
datetime
(时间处理)、jose
(JWT编解码)、passlib
(密码加密)相关库及类型提示工具; - 密码加密配置:使用
CryptContext
初始化bcrypt加密上下文,自动处理过时加密方案; - 密码操作函数:
verify_password
:验证明文密码与加密后密码是否匹配;get_password_hash
:将明文密码转换为bcrypt加密哈希值;
- JWT配置:定义密钥(
SECRET_KEY
,注:实际需从环境变量获取)、加密算法(HS256
)、默认令牌过期时间(30分钟); - 令牌生成函数:
create_access_token
接收自定义数据,可指定过期时间(默认15分钟),生成带过期时间的JWT令牌。
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")# JWT配置
SECRET_KEY = "your-secret-key" # 在实际应用中应该从环境变量获取
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30# 验证密码
def verify_password(plain_password, hashed_password):return pwd_context.verify(plain_password, hashed_password)# 获取密码哈希值
def get_password_hash(password):return pwd_context.hash(password)# 创建访问令牌
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):to_encode = data.copy()if expires_delta:expire = datetime.utcnow() + expires_deltaelse:expire = datetime.utcnow() + timedelta(minutes=15)to_encode.update({"exp": expire})encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)return encoded_jwt
认证流程
- 用户通过登录接口提交用户名和密码
- 系统验证用户凭据,如果验证成功则生成JWT Token
- 客户端在后续请求中携带JWT Token
- 服务端通过中间件验证Token的有效性
- 如果Token有效,则允许访问受保护的资源
代码简短总结
导入依赖:引入FastAPI相关模块(Depends、HTTPException、status)及OAuth2密码认证方案模块(OAuth2PasswordBearer)。
初始化OAuth2认证:创建oauth2_scheme
实例,指定获取Token的接口地址为“auth/token”,用于获取请求中的Bearer Token。
定义获取当前用户函数:
- 函数
get_current_user
依赖oauth2_scheme
(获取Token)和get_db
(获取数据库会话)。- 定义401未授权异常
credentials_exception
,含“无法验证凭证”提示及Bearer认证头信息。 - 尝试解码JWT Token:用
SECRET_KEY
和指定ALGORITHM
解码Token,提取“sub”字段作为用户名,若不存在则抛异常;若解码失败(如Token无效),捕获JWTError
并抛异常。 - 验证用户存在性:通过
get_user
函数从数据库查询该用户名对应的用户,若用户不存在则抛异常,最终返回查询到的用户。
- 定义401未授权异常
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer# OAuth2密码Bearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")# 获取当前用户
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail="Could not validate credentials",headers={"WWW-Authenticate": "Bearer"},)try:payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])username: str = payload.get("sub")if username is None:raise credentials_exceptiontoken_data = TokenData(username=username)except JWTError:raise credentials_exceptionuser = get_user(db, username=token_data.username)if user is None:raise credentials_exceptionreturn user
总结
本文详细介绍了RAG知识库管理系统后端服务的核心实现,包括:
- FastAPI框架配置:通过模块化设计和路由注册,构建了清晰的应用结构
- 路由设计:将认证、知识库管理和聊天功能分别组织在不同的路由模块中
- 数据库集成:使用SQLAlchemy ORM与MySQL数据库进行交互,定义了完整的数据模型
- 用户认证系统:基于JWT Token实现安全的用户认证机制
通过这些设计,系统具备了良好的可扩展性和安全性,为后续的功能开发奠定了坚实的基础。在下一篇文章中,我们将深入探讨RAG核心模块的实现细节。