[RAG] Retrieval-Based RAG with Vector Search (BGE Example)
The RAG bot class
- Text vectorization: use the BGE model to encode documents and queries into vectors. (BGE is an open-source embedding model optimized for retrieval tasks; besides the API calls used in this article, the open-source BGE model can also be deployed locally via Hugging Face, as shown in the sketch after this list.)
- Vector retrieval: find the document chunks most relevant to the query in the database.
- Answer generation: combine the retrieved results with the user input and call the ERNIE model to generate the final answer.
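For local deployment, a minimal sketch is shown below. It assumes the sentence-transformers package and the BAAI/bge-large-en-v1.5 checkpoint from Hugging Face; neither is part of the original code, which calls the Qianfan API instead.

from sentence_transformers import SentenceTransformer

# Download (on first use) and load the open-source BGE model from Hugging Face
local_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")

def get_embeddings_bge_local(texts):
    # normalize_embeddings=True makes cosine similarity equivalent to a dot product
    return local_bge.encode(texts, normalize_embeddings=True).tolist()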
class RAG_Bot:
    def __init__(self, vector_db, llm_api, n_results=2):
        self.vector_db = vector_db
        self.llm_api = llm_api
        self.n_results = n_results

    def chat(self, user_query):
        # 1. Retrieve
        search_results = self.vector_db.search(user_query, self.n_results)
        # 2. Build the prompt
        prompt = build_prompt(
            prompt_template, context=search_results['documents'][0], query=user_query)
        # 3. Call the LLM
        response = self.llm_api(prompt)
        return response
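chat() relies on a prompt_template string and a build_prompt helper, just as the usage example below relies on vector_db and the get_completion LLM wrapper; all of these are defined in earlier steps of the tutorial rather than in this section. A minimal sketch of what the two prompt helpers might look like:

prompt_template = """
Answer the user's question strictly based on the known information below.
If the known information does not contain the answer, say you cannot answer.

Known information:
__CONTEXT__

User question:
__QUERY__
"""

def build_prompt(template, **kwargs):
    # Replace each __KEY__ placeholder with the corresponding value;
    # lists of strings are joined with newlines
    prompt = template
    for key, value in kwargs.items():
        if isinstance(value, list):
            value = "\n".join(str(item) for item in value)
        prompt = prompt.replace(f"__{key.upper()}__", str(value))
    return prompt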
######
# Create a RAG bot
bot = RAG_Bot(
    vector_db,
    llm_api=get_completion
)

user_query = "How many parameters does llama 2 have?"
response = bot.chat(user_query)
print(response)
#####
Llama 2 has 7B, 13B, and 70B parameter versions.
MyVectorDBConnector:
Custom vector database that stores document vectors.
embedding_fn=get_embeddings_bge: use the BGE model to generate the vectors.
add_documents(paragraphs): add documents to the database (paragraphs is defined in an earlier step).
RAG_Bot:
Retrieval-augmented generation bot that combines vector search with LLM generation.
chat(user_query): runs the "retrieve → generate" flow:
Vectorize the user query.
Retrieve relevant documents from the database.
Use the retrieved results as context and call the ERNIE model to generate the answer.
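MyVectorDBConnector itself is defined in an earlier part of the tutorial and is only used here. For reference, a minimal in-memory sketch that matches the interface used above; the brute-force cosine-similarity search and the return format are assumptions (the original likely wraps a real vector store):

import numpy as np

class MyVectorDBConnector:
    def __init__(self, name, embedding_fn):
        self.name = name
        self.embedding_fn = embedding_fn
        self.documents = []
        self.embeddings = []

    def add_documents(self, documents):
        # Embed and store every document
        self.documents.extend(documents)
        self.embeddings.extend(self.embedding_fn(documents))

    def search(self, query, top_n):
        # Cosine similarity between the query vector and every stored vector
        query_vec = np.array(self.embedding_fn([query])[0])
        doc_matrix = np.array(self.embeddings)
        scores = doc_matrix @ query_vec / (
            np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_vec))
        top_idx = np.argsort(scores)[::-1][:top_n]
        # Same shape as RAG_Bot.chat expects: results['documents'][0] is the list of hits
        return {"documents": [[self.documents[i] for i in top_idx]]}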
Using a domestic (Chinese) model
import json
import requests
import os
# Obtain an access token from the auth endpoint
def get_access_token():
    """
    Generate an auth signature (access token) from the AK/SK.
    :return: access_token, or None on error
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": os.getenv('ERNIE_CLIENT_ID'),
        "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
    }
    return str(requests.post(url, params=params).json().get("access_token"))
# Call the BGE embedding endpoint on Wenxin Qianfan
def get_embeddings_bge(prompts):
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en?access_token=" + get_access_token()
    payload = json.dumps({
        "input": prompts
    })
    headers = {'Content-Type': 'application/json'}
    response = requests.request(
        "POST", url, headers=headers, data=payload).json()
    data = response["data"]
    return [x["embedding"] for x in data]
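A quick sanity check (requires valid ERNIE_CLIENT_ID / ERNIE_CLIENT_SECRET; the 1024-dimensional output is the typical size for bge_large_en and may differ):

vecs = get_embeddings_bge(["hello world", "retrieval augmented generation"])
print(len(vecs), len(vecs[0]))  # e.g. "2 1024"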
# Call the ERNIE 4.0 chat endpoint
def get_completion_ernie(prompt):
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + get_access_token()
    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    })
    headers = {'Content-Type': 'application/json'}
    response = requests.request(
        "POST", url, headers=headers, data=payload).json()
    return response["result"]
# Create a vector database object
new_vector_db = MyVectorDBConnector(
    "demo_ernie",
    embedding_fn=get_embeddings_bge
)
# Add documents to the vector database
new_vector_db.add_documents(paragraphs)

# Create a RAG bot
new_bot = RAG_Bot(
    new_vector_db,
    llm_api=get_completion_ernie
)
user_query = "how many parameters does llama 2 have?"
response = new_bot.chat(user_query)
print(response)
Extended practice
1. Optimize access token management
- Cache the token: reduce calls to the auth endpoint and refresh only when the token expires.
- Example code:
from datetime import datetime, timedelta

class TokenManager:
    _token = None
    _expires_at = None

    @classmethod
    def get_token(cls):
        if cls._token is None or datetime.now() > cls._expires_at:
            cls._refresh_token()
        return cls._token

    @classmethod
    def _refresh_token(cls):
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": os.getenv('ERNIE_CLIENT_ID'),
            "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
        }
        response = requests.post(url, params=params)
        response.raise_for_status()
        data = response.json()
        cls._token = data["access_token"]
        # Tokens are valid for 30 days by default, but prefer the expires_in actually returned;
        # refresh 5 minutes early
        cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)
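With this in place, the direct get_access_token() calls above can simply be replaced with TokenManager.get_token().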
2. Strengthen error handling and retries
- Retry network requests: use the tenacity library to automatically retry failed requests.
- Catch exceptions: handle common errors explicitly (e.g. network timeouts, invalid responses).
- Example code:
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import requests.exceptions as req_exceptions

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
)
def safe_api_request(url, headers, payload):
    try:
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        return response.json()
    except req_exceptions.HTTPError as e:
        if e.response.status_code == 401:
            TokenManager._refresh_token()  # the token may have expired; force a refresh
            raise
        raise ValueError(f"API error: {e.response.text}")
3. Validate environment variables
- Check at startup: make sure the key configuration is set correctly.
- Example code:
def validate_env_vars():
    required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise EnvironmentError(f"Missing environment variables: {', '.join(missing_vars)}")

# Call this during program initialization
validate_env_vars()
4. Optimize vector database interaction
- Batch-insert documents: reduce the number of API calls.
- Chunking strategy: split text into chunks according to the embedding model's maximum input length (a possible chunking implementation is sketched after the example below).
- Example optimization (assuming MyVectorDBConnector is used):

class MyVectorDBConnector:
    def __init__(self, name, embedding_fn, chunk_size=512):
        self.embedding_fn = embedding_fn
        self.chunk_size = chunk_size  # set according to the model's maximum supported length

    def add_documents(self, documents):
        chunks = self._chunk_documents(documents)
        embeddings = self.embedding_fn(chunks)
        # Store chunks and their embeddings in the vector database in bulk

    def _chunk_documents(self, documents):
        # Implement sentence-based or fixed-length chunking here
        pass
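One way to fill in the stubbed _chunk_documents is to delegate to a helper like the sketch below. Character-level windows with a small overlap are an assumption; a sentence- or token-based splitter would track the model's limit more accurately.

def chunk_documents(documents, chunk_size=512, overlap=50):
    # Split each document into fixed-length character windows with a small overlap
    chunks = []
    step = chunk_size - overlap
    for doc in documents:
        for start in range(0, len(doc), step):
            chunk = doc[start:start + chunk_size]
            if chunk.strip():
                chunks.append(chunk)
    return chunks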
Optimized code example
Core logic after integrating the improvements above:
import os
import json
import logging
from datetime import datetime, timedelta
import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import requests.exceptions as req_exceptions
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Validate environment variables
def validate_env_vars():
    required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise EnvironmentError(f"Missing env vars: {', '.join(missing_vars)}")

validate_env_vars()
# Token management
class TokenManager:
    _token = None
    _expires_at = None

    @classmethod
    def get_token(cls):
        if cls._token is None or datetime.now() > cls._expires_at:
            cls._refresh_token()
        return cls._token

    @classmethod
    def _refresh_token(cls):
        logger.info("Refreshing access token...")
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": os.getenv('ERNIE_CLIENT_ID'),
            "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
        }
        response = requests.post(url, params=params)
        response.raise_for_status()
        data = response.json()
        cls._token = data["access_token"]
        # Refresh 5 minutes before the token's reported expiry
        cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)
# Safe API requests with retries
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
)
def safe_api_request(url, headers, payload):
    try:
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        return response.json()
    except req_exceptions.HTTPError as e:
        if e.response.status_code == 401:
            TokenManager._refresh_token()  # the token may have expired; force a refresh
            raise
        logger.error(f"API Error: {e.response.text}")
        raise
# Shared ERNIE API call wrapper
def call_ernie_api(endpoint, payload):
    base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
    url = f"{base_url}/{endpoint}?access_token={TokenManager.get_token()}"
    headers = {'Content-Type': 'application/json'}
    return safe_api_request(url, headers, json.dumps(payload))

# Embedding endpoint
def get_embeddings_bge(prompts):
    logger.info(f"Generating embeddings for {len(prompts)} prompts")
    response = call_ernie_api("embeddings/bge_large_en", {"input": prompts})
    return [x["embedding"] for x in response["data"]]

# ERNIE 4.0 chat endpoint
def get_completion_ernie(prompt):
    logger.info(f"Generating completion for prompt: {prompt[:50]}...")
    response = call_ernie_api("chat/completions_pro", {
        "messages": [{"role": "user", "content": prompt}]
    })
    return response["result"]
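These optimized functions are drop-in replacements for the earlier ones, so the pipeline from the previous section can be reused unchanged. A short sketch, assuming MyVectorDBConnector, RAG_Bot, and paragraphs are already defined as above:

optimized_vector_db = MyVectorDBConnector("demo_ernie_optimized", embedding_fn=get_embeddings_bge)
optimized_vector_db.add_documents(paragraphs)

optimized_bot = RAG_Bot(optimized_vector_db, llm_api=get_completion_ernie)
print(optimized_bot.chat("how many parameters does llama 2 have?"))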