当前位置: 首页 > news >正文

【RAG】基于向量检索的 RAG (BGE示例)

RAG机器人 结构体

  • 文本向量化: 使用 BGE 模型将文档和查询编码为向量。
    (BGE 是专为检索任务优化的开源 Embedding 模型,除了本文API调用,也可以通过Hugging Face 本地部署BGE 开源模型)

  • 向量检索: 从数据库中找到与查询相关的文档片段。

  • 答案生成: 结合检索结果和用户输入,调用文心模型生成最终回答。

class RAG_Bot:
    def __init__(self, vector_db, llm_api, n_results=2):
        self.vector_db = vector_db
        self.llm_api = llm_api
        self.n_results = n_results

    def chat(self, user_query):
        # 1. 检索
        search_results = self.vector_db.search(user_query, self.n_results)

        # 2. 构建 Prompt
        prompt = build_prompt(
            prompt_template, context=search_results['documents'][0], query=user_query)

        # 3. 调用 LLM
        response = self.llm_api(prompt)
        return response
######

# 创建一个RAG机器人
bot = RAG_Bot(
    vector_db,
    llm_api=get_completion
)

user_query = "llama 2有多少参数?"

response = bot.chat(user_query)

print(response)

#####
llama 2有7B, 13B和70B参数。

MyVectorDBConnector:

自定义向量数据库,存储文档向量。
embedding_fn=get_embeddings_bge: 使用 BGE 模型生成向量。
add_documents(paragraphs): 向数据库中添加文档(已提前定义 paragraphs)。

RAG_Bot:

检索增强生成机器人,结合向量搜索与大模型生成。
chat(user_query): 执行“检索→生成”流程:
将用户查询向量化。
从数据库检索相关文档。
将检索结果作为上下文,调用文心模型生成回答。

使用国产模型

import json
import requests
import os

# 通过鉴权接口获取 access token


def get_access_token():
    """
    使用 AK,SK 生成鉴权签名(Access Token)
    :return: access_token,或是None(如果错误)
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": os.getenv('ERNIE_CLIENT_ID'),
        "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
    }

    return str(requests.post(url, params=params).json().get("access_token"))

# 调用文心千帆 调用 BGE Embedding 接口


def get_embeddings_bge(prompts):
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en?access_token=" + get_access_token()
    payload = json.dumps({
        "input": prompts
    })
    headers = {'Content-Type': 'application/json'}

    response = requests.request(
        "POST", url, headers=headers, data=payload).json()
    data = response["data"]
    return [x["embedding"] for x in data]


# 调用文心4.0对话接口
def get_completion_ernie(prompt):

    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + get_access_token()
    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    })

    headers = {'Content-Type': 'application/json'}

    response = requests.request(
        "POST", url, headers=headers, data=payload).json()

    return response["result"]

# 创建一个向量数据库对象
new_vector_db = MyVectorDBConnector(
    "demo_ernie",
    embedding_fn=get_embeddings_bge
)
# 向向量数据库中添加文档
new_vector_db.add_documents(paragraphs)

# 创建一个RAG机器人
new_bot = RAG_Bot(
    new_vector_db,
    llm_api=get_completion_ernie
)

user_query = "how many parameters does llama 2 have?"

response = new_bot.chat(user_query)

print(response)

拓展实践

1. 优化 Access Token 管理
  • 缓存 Token:减少鉴权接口调用次数,仅在 Token 过期时刷新。
  • 示例代码
    from datetime import datetime, timedelta
    
    class TokenManager:
        _token = None
        _expires_at = None
    
        @classmethod
        def get_token(cls):
            if cls._token is None or datetime.now() > cls._expires_at:
                cls._refresh_token()
            return cls._token
    
        @classmethod
        def _refresh_token(cls):
            url = "https://aip.baidubce.com/oauth/2.0/token"
            params = {
                "grant_type": "client_credentials",
                "client_id": os.getenv('ERNIE_CLIENT_ID'),
                "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
            }
            response = requests.post(url, params=params)
            response.raise_for_status()
            data = response.json()
            cls._token = data["access_token"]
            # 默认 Token 有效期为 30 天,但建议按实际返回的 expires_in 设置
            cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)  # 提前 5 分钟刷新
    
2. 增强错误处理与重试
  • 重试网络请求:使用 tenacity 库自动重试失败请求。
  • 捕获异常:明确处理常见错误(如网络超时、无效响应)。
  • 示例代码
    from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
    import requests.exceptions as req_exceptions
    
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
    )
    def safe_api_request(url, headers, payload):
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=10)
            response.raise_for_status()
            return response.json()
        except req_exceptions.HTTPError as e:
            if response.status_code == 401:
                TokenManager._refresh_token()  # Token 可能过期,强制刷新
                raise
            raise ValueError(f"API 错误: {e.response.text}")
    
3. 验证环境变量
  • 启动时检查:确保关键配置已正确设置。
  • 示例代码
    def validate_env_vars():
        required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']
        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            raise EnvironmentError(f"缺少环境变量: {', '.join(missing_vars)}")
    
    # 在程序初始化时调用
    validate_env_vars()
    
4. 优化向量数据库交互
  • 批量插入文档:减少 API 调用次数。
  • 分块策略:根据 Embedding 模型的最大输入长度分块文本。
  • 示例优化(假设使用 MyVectorDBConnector):
    class MyVectorDBConnector:
        def __init__(self, name, embedding_fn, chunk_size=512):
            self.embedding_fn = embedding_fn
            self.chunk_size = chunk_size  # 根据模型支持的最大长度设置
    
        def add_documents(self, documents):
            chunks = self._chunk_documents(documents)
            embeddings = self.embedding_fn(chunks)
            # 批量存储到向量数据库
    
        def _chunk_documents(self, documents):
            # 实现基于句子或固定长度的分块逻辑
            pass
    

优化后的代码示例

整合上述改进后的核心逻辑:

import os
import json
import logging
from datetime import datetime, timedelta
import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import requests.exceptions as req_exceptions

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 环境变量校验
def validate_env_vars():
    required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise EnvironmentError(f"Missing env vars: {', '.join(missing_vars)}")
validate_env_vars()

# Token 管理
class TokenManager:
    _token = None
    _expires_at = None

    @classmethod
    def get_token(cls):
        if cls._token is None or datetime.now() > cls._expires_at:
            cls._refresh_token()
        return cls._token

    @classmethod
    def _refresh_token(cls):
        logger.info("Refreshing access token...")
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": os.getenv('ERNIE_CLIENT_ID'),
            "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
        }
        response = requests.post(url, params=params)
        response.raise_for_status()
        data = response.json()
        cls._token = data["access_token"]
        cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)

# 安全 API 请求
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
)
def safe_api_request(url, headers, payload):
    try:
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        return response.json()
    except req_exceptions.HTTPError as e:
        if response.status_code == 401:
            TokenManager._refresh_token()
            raise
        logger.error(f"API Error: {e.response.text}")
        raise

# 公共 API 调用封装
def call_ernie_api(endpoint, payload):
    base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
    url = f"{base_url}/{endpoint}?access_token={TokenManager.get_token()}"
    headers = {'Content-Type': 'application/json'}
    return safe_api_request(url, headers, json.dumps(payload))

# Embedding 接口
def get_embeddings_bge(prompts):
    logger.info(f"Generating embeddings for {len(prompts)} prompts")
    response = call_ernie_api("embeddings/bge_large_en", {"input": prompts})
    return [x["embedding"] for x in response["data"]]

# 文心 4.0 对话接口
def get_completion_ernie(prompt):
    logger.info(f"Generating completion for prompt: {prompt[:50]}...")
    response = call_ernie_api("chat/completions_pro", {
        "messages": [{"role": "user", "content": prompt}]
    })
    return response["result"]

相关文章:

  • Leetcode 刷题记录 05 —— 普通数组
  • 硬件学习笔记--48 磁保持继电器相关基础知识介绍
  • 【每日学点HarmonyOS Next知识】 状态变量、公共Page、可见区域变化回调、接收参数、拖拽排序控件
  • 前端数据模拟 Mock.js 学习笔记(附带详细)
  • 中小学信息学特长生试卷(C++)
  • 6.聊天室环境安装 - Ubuntu22.04 - elasticsearch(es)的安装和使用
  • clickhouse执行进度
  • How to install nacos 2.5 with podman
  • 汇编的伪指令
  • Vue3 模板引用:打破数据驱动的次元壁(附高阶玩法)
  • openwrt路由系统------lua、uci的关系
  • SAP HANA Merge
  • 【C++设计模式】第十六篇:迭代器模式(Iterator)
  • mysql进阶(五)
  • Windows控制台函数:控制台读取输入函数ReadConsoleA()
  • STM32中输入/输出有无默认电平
  • C++的内存管理
  • 单片机项目复刻需要的准备工作
  • SpringBoot参数校验:@Valid 与 @Validated 详解
  • nginx反向代理功能
  • 冰雹造成车损能赔吗?如何理赔?机构答疑
  • 重庆市委原常委、政法委原书记陆克华被决定逮捕
  • 中巡组在行动丨①震慑:这些地区有官员落马
  • 郑培凯:汤显祖的“至情”与罗汝芳的“赤子之心”
  • 筑牢安全防线、提升应急避难能力水平,5项国家标准发布
  • 中国女足将于5月17日至6月2日赴美国集训并参加邀请赛