RAG Multi-Provider API Processing System Design: An Enterprise-Grade LLM Integration Architecture in Practice
Full source code is included at the end.
Preface
In a modern RAG system, integrating large language model APIs is one of the core tasks. This article dissects the multi-provider API processing system from a championship-winning RAG challenge project. The system puts OpenAI, IBM, Google Gemini, and Alibaba Cloud DashScope behind a single interface and, through a unified abstraction layer, retry logic, and structured-output handling, delivers highly available, high-performance LLM integration for enterprise use.
System Architecture Overview
The API processing system follows a layered-abstraction design:
- Unified interface layer: APIProcessor exposes a single calling interface to the rest of the system
- Provider adapter layer: the Base*Processor classes absorb the API differences between LLM providers
- Feature enhancement layer: structured output, retry logic, asynchronous processing, and other advanced capabilities
- Tooling layer: token counting, JSON repair, prompt management, and other supporting utilities
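In day-to-day use, callers only touch the top layer. A minimal usage sketch (the provider name and prompt are illustrative; the classes are shown in detail below):

```python
# Pick a provider once; everything else goes through the same call shape.
processor = APIProcessor(provider="dashscope")  # or "openai", "ibm", "gemini"

answer = processor.send_message(
    system_content="You are a helpful assistant.",
    human_content="Summarize the key risks mentioned in the annual report.",
)
print(answer)
```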
Core Components in Detail
1. Unified API Processor (APIProcessor)
APIProcessor is the system's facade class and provides a single entry point for every supported LLM provider:
```python
class APIProcessor:
    def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] = "dashscope"):
        self.provider = provider.lower()
        if self.provider == "openai":
            self.processor = BaseOpenaiProcessor()
        elif self.provider == "ibm":
            self.processor = BaseIBMAPIProcessor()
        elif self.provider == "gemini":
            self.processor = BaseGeminiProcessor()
        elif self.provider == "dashscope":
            self.processor = BaseDashscopeProcessor()

    def send_message(
        self,
        model=None,
        temperature=0.5,
        seed=None,
        system_content="You are a helpful assistant.",
        human_content="Hello!",
        is_structured=False,
        response_format=None,
        **kwargs
    ):
        """Unified send interface that routes to the configured processor."""
        if model is None:
            model = self.processor.default_model
        return self.processor.send_message(
            model=model,
            temperature=temperature,
            seed=seed,
            system_content=system_content,
            human_content=human_content,
            is_structured=is_structured,
            response_format=response_format,
            **kwargs
        )
```
Design highlights:
- Unified interface: hides the API differences between providers
- Dynamic routing: instantiates the matching processor based on the configured provider
- Parameter pass-through: provider-specific parameters are forwarded via **kwargs
- Type safety: the Literal type constrains provider names at the call site
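Because every adapter exposes the same send_message contract, adding another provider only requires a new Base*Processor plus one extra branch in APIProcessor. A hypothetical sketch (this class is not part of the original project):

```python
# Hypothetical adapter; the class and model name below are placeholders.
class BaseAnthropicProcessor:
    def __init__(self):
        self.default_model = "claude-placeholder-model"

    def send_message(self, model=None, temperature=0.5, seed=None,
                     system_content="You are a helpful assistant.",
                     human_content="Hello!", is_structured=False,
                     response_format=None, **kwargs):
        # Call the provider SDK here and normalize the result so it looks
        # like the other Base*Processor return values (text or dict).
        raise NotImplementedError

# APIProcessor.__init__ would then need just one more elif branch.
```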
2. OpenAI Processor (BaseOpenaiProcessor)
The OpenAI processor fully wraps the OpenAI API, including structured output and token accounting:
```python
class BaseOpenaiProcessor:
    def __init__(self):
        self.llm = self.set_up_llm()
        self.default_model = 'gpt-4o-2024-08-06'

    def set_up_llm(self):
        # Load the OpenAI API key and initialize the client
        load_dotenv()
        llm = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=None,
            max_retries=2
        )
        return llm

    def send_message(
        self,
        model=None,
        temperature=0.5,
        seed=None,
        system_content='You are a helpful assistant.',
        human_content='Hello!',
        is_structured=False,
        response_format=None
    ):
        if model is None:
            model = self.default_model
        params = {
            "model": model,
            "seed": seed,
            "messages": [
                {"role": "system", "content": system_content},
                {"role": "user", "content": human_content}
            ]
        }
        # Some models do not accept a temperature parameter
        if "o3-mini" not in model:
            params["temperature"] = temperature

        if not is_structured:
            # Plain text output
            completion = self.llm.chat.completions.create(**params)
            content = completion.choices[0].message.content
        else:
            # Structured output
            params["response_format"] = response_format
            completion = self.llm.beta.chat.completions.parse(**params)
            response = completion.choices[0].message.parsed
            content = response.dict()

        # Record usage statistics
        self.response_data = {
            "model": completion.model,
            "input_tokens": completion.usage.prompt_tokens,
            "output_tokens": completion.usage.completion_tokens
        }
        print(self.response_data)
        return content

    @staticmethod
    def count_tokens(string, encoding_name="o200k_base"):
        # Count the number of tokens in a string
        encoding = tiktoken.get_encoding(encoding_name)
        tokens = encoding.encode(string)
        token_count = len(tokens)
        return token_count
```
Technical highlights:
- Structured output: Pydantic response models via the beta parse endpoint
- Model compatibility: parameter differences are handled per model (e.g. temperature is skipped for o3-mini)
- Token accounting: exact prompt/completion token counts from the API response, plus a tiktoken-based counter
- Error handling: client-level timeout and retry settings (max_retries=2) on the OpenAI client
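A short sketch of driving the structured-output path; the AnswerSchema model here is illustrative (the real project keeps its schemas in src/prompts.py):

```python
from pydantic import BaseModel

class AnswerSchema(BaseModel):  # illustrative schema
    final_answer: str
    reasoning_summary: str

openai_processor = BaseOpenaiProcessor()
result = openai_processor.send_message(
    system_content="Answer strictly as JSON matching the schema.",
    human_content="What is the capital of France?",
    is_structured=True,
    response_format=AnswerSchema,
)
print(result)  # dict built from the parsed Pydantic model
print(BaseOpenaiProcessor.count_tokens("What is the capital of France?"))
```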
3. IBM API Processor (BaseIBMAPIProcessor)
The IBM processor provides integration with enterprise AI services such as IBM Watson:
```python
class BaseIBMAPIProcessor:
    def __init__(self):
        load_dotenv()
        self.api_token = os.getenv("IBM_API_KEY")
        self.base_url = "https://rag.timetoact.at/ibm"
        self.default_model = 'meta-llama/llama-3-3-70b-instruct'

    def check_balance(self):
        """Query the current API balance"""
        balance_url = f"{self.base_url}/balance"
        headers = {"Authorization": f"Bearer {self.api_token}"}
        try:
            response = requests.get(balance_url, headers=headers)
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as err:
            print(f"Error checking balance: {err}")
            return None

    def get_embeddings(self, texts, model_id="ibm/granite-embedding-278m-multilingual"):
        """Get vector embeddings for the given texts"""
        embeddings_url = f"{self.base_url}/embeddings"
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        payload = {"inputs": texts, "model_id": model_id}
        try:
            response = requests.post(embeddings_url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as err:
            print(f"Error getting embeddings: {err}")
            return None

    def send_message(
        self,
        model=None,
        temperature=0.5,
        seed=None,
        system_content='You are a helpful assistant.',
        human_content='Hello!',
        is_structured=False,
        response_format=None,
        max_new_tokens=5000,
        min_new_tokens=1,
        **kwargs
    ):
        if model is None:
            model = self.default_model
        text_generation_url = f"{self.base_url}/text_generation"
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        # Prepare the input messages
        input_messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": human_content}
        ]
        # Prepare the generation parameters
        parameters = {
            "temperature": temperature,
            "random_seed": seed,
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min_new_tokens,
            **kwargs
        }
        payload = {
            "input": input_messages,
            "model_id": model,
            "parameters": parameters
        }
        try:
            response = requests.post(text_generation_url, headers=headers, json=payload)
            response.raise_for_status()
            completion = response.json()
            content = completion.get("results")[0].get("generated_text")
            self.response_data = {
                "model": completion.get("model_id"),
                "input_tokens": completion.get("results")[0].get("input_token_count"),
                "output_tokens": completion.get("results")[0].get("generated_token_count")
            }
            # Structured output handling
            if is_structured and response_format is not None:
                try:
                    repaired_json = repair_json(content)
                    parsed_dict = json.loads(repaired_json)
                    validated_data = response_format.model_validate(parsed_dict)
                    content = validated_data.model_dump()
                except Exception as err:
                    print("Error processing structured response, attempting to reparse...")
                    content = self._reparse_response(content, system_content)
            return content
        except requests.HTTPError as err:
            print(f"Error generating text: {err}")
            return None
```
Enterprise-grade features:
- Balance checks: real-time monitoring of the remaining API quota
- Embedding service: multilingual text embeddings (granite-embedding-278m-multilingual by default)
- Flexible parameters: rich generation settings (max/min new tokens, random seed, extras via **kwargs)
- Error recovery: JSON repair combined with an LLM-based reparse fallback
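A usage sketch for the auxiliary endpoints, assuming IBM_API_KEY is set in the environment (the example texts are made up):

```python
ibm = BaseIBMAPIProcessor()

# Check how much quota is left before launching a large batch.
print(ibm.check_balance())

# Request multilingual embeddings; the method simply returns the gateway's JSON payload.
embeddings = ibm.get_embeddings([
    "What was the total revenue in 2023?",
    "List the members of the board.",
])
print(embeddings)
```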
4. Google Gemini Processor (BaseGeminiProcessor)
The Gemini processor integrates Google's Gemini models:
```python
class BaseGeminiProcessor:
    def __init__(self):
        self.llm = self._set_up_llm()
        self.default_model = 'gemini-2.0-flash-001'

    def _set_up_llm(self):
        load_dotenv()
        api_key = os.getenv("GEMINI_API_KEY")
        genai.configure(api_key=api_key)
        return genai

    @retry(
        wait=wait_fixed(20),
        stop=stop_after_attempt(3),
        before_sleep=lambda retry_state: print(
            f"\nAPI Error: {retry_state.outcome.exception()}\nWaiting 20 seconds...\n"
        ),
    )
    def _generate_with_retry(self, model, human_content, generation_config):
        """Content generation wrapped with retry logic"""
        try:
            return model.generate_content(
                human_content,
                generation_config=generation_config
            )
        except Exception as e:
            if getattr(e, '_attempt_number', 0) == 3:
                print(f"\nRetry failed. Error: {str(e)}\n")
            raise

    def send_message(
        self,
        model=None,
        temperature: float = 0.5,
        seed=12345,
        system_content: str = "You are a helpful assistant.",
        human_content: str = "Hello!",
        is_structured: bool = False,
        response_format: Optional[Type[BaseModel]] = None,
    ) -> Union[str, Dict, None]:
        if model is None:
            model = self.default_model
        generation_config = {"temperature": temperature}
        # Gemini takes a single prompt, so system and user content are concatenated
        prompt = f"{system_content}\n\n---\n\n{human_content}"
        model_instance = self.llm.GenerativeModel(
            model_name=model,
            generation_config=generation_config
        )
        try:
            response = self._generate_with_retry(model_instance, prompt, generation_config)
            self.response_data = {
                "model": response.model_version,
                "input_tokens": response.usage_metadata.prompt_token_count,
                "output_tokens": response.usage_metadata.candidates_token_count
            }
            if is_structured and response_format is not None:
                return self._parse_structured_response(response.text, response_format)
            return response.text
        except Exception as e:
            raise Exception(f"API request failed after retries: {str(e)}")

    def _parse_structured_response(self, response_text, response_format):
        """Parse a structured response"""
        try:
            repaired_json = repair_json(response_text)
            parsed_dict = json.loads(repaired_json)
            validated_data = response_format.model_validate(parsed_dict)
            return validated_data.model_dump()
        except Exception as err:
            print(f"Error parsing structured response: {err}")
            return self._reparse_response(response_text, response_format)
```
Technical highlights:
- Automatic retry: tenacity retries up to three attempts with a fixed 20-second wait between them
- Prompt adaptation: system and user content are concatenated into the single prompt format Gemini expects
- Usage statistics: token counts read from the response's usage metadata
- Structured parsing: JSON repair plus an LLM-based reparse fallback
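The project uses a fixed 20-second wait. If exponential backoff is preferred, tenacity supports it directly; a sketch of the alternative policy (not what the original code does):

```python
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
    wait=wait_exponential(multiplier=1, min=4, max=60),  # roughly 4s, 8s, 16s, ... capped at 60s
    stop=stop_after_attempt(3),
)
def generate_with_backoff(model_instance, prompt, generation_config):
    return model_instance.generate_content(prompt, generation_config=generation_config)
```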
5. Alibaba Cloud DashScope Processor (BaseDashscopeProcessor)
The DashScope processor integrates Alibaba Cloud's Tongyi Qianwen (Qwen) models:
```python
class BaseDashscopeProcessor:
    def __init__(self):
        # Read the API key from the environment
        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
        self.default_model = 'qwen-turbo-latest'

    def send_message(
        self,
        model="qwen-turbo-latest",
        temperature=0.1,
        seed=None,
        system_content='You are a helpful assistant.',
        human_content='Hello!',
        is_structured=False,
        response_format=None,
        **kwargs
    ):
        """Send a message to the DashScope Qwen model"""
        if model is None:
            model = self.default_model
        # Assemble the messages
        messages = []
        if system_content:
            messages.append({"role": "system", "content": system_content})
        if human_content:
            messages.append({"role": "user", "content": human_content})
        # Call dashscope Generation.call
        response = dashscope.Generation.call(
            model=model,
            messages=messages,
            temperature=temperature,
            result_format='message'
        )
        # Normalize to the unified interface format
        if hasattr(response, 'output') and hasattr(response.output, 'choices'):
            content = response.output.choices[0].message.content
        else:
            content = str(response)
        # Keep the interface consistent with the other processors
        self.response_data = {"model": model, "input_tokens": None, "output_tokens": None}
        # Unified return format
        return {"final_answer": content}
```
Highlights for Chinese-language deployments:
- Local optimization: Qwen models tuned for Chinese-language scenarios
- Simple interface: a concise calling convention
- Compatible design: keeps the same interface shape as the other providers (always returns a dict)
- Cost advantage: comparatively low usage cost
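A usage sketch, assuming DASHSCOPE_API_KEY is set in the environment:

```python
qwen = BaseDashscopeProcessor()
result = qwen.send_message(
    system_content="You are a helpful assistant.",
    human_content="用一句话介绍 RAG。",
)
# The DashScope processor always wraps the text in a dict for downstream consistency.
print(result["final_answer"])
```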
6. RAG Context Handling
The system provides dedicated helpers for answering questions over RAG context:
```python
def get_answer_from_rag_context(self, question, rag_context, schema, model):
    """Generate an answer from the RAG context"""
    system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
    answer_dict = self.processor.send_message(
        model=model,
        system_content=system_prompt,
        human_content=user_prompt.format(context=rag_context, question=question),
        is_structured=True,
        response_format=response_format
    )
    self.response_data = self.processor.response_data
    # Fallback: make sure the answer structure is complete
    if 'step_by_step_analysis' not in answer_dict:
        answer_dict = {
            "step_by_step_analysis": "",
            "reasoning_summary": "",
            "relevant_pages": [],
            "final_answer": answer_dict.get("final_answer", "N/A")
        }
    return answer_dict

def _build_rag_context_prompts(self, schema):
    """Select the prompts for the given answer type"""
    use_schema_prompt = True if self.provider == "ibm" or self.provider == "gemini" else False
    if schema == "name":
        system_prompt = (prompts.AnswerWithRAGContextNamePrompt.system_prompt_with_schema
                         if use_schema_prompt
                         else prompts.AnswerWithRAGContextNamePrompt.system_prompt)
        response_format = prompts.AnswerWithRAGContextNamePrompt.AnswerSchema
        user_prompt = prompts.AnswerWithRAGContextNamePrompt.user_prompt
    elif schema == "number":
        system_prompt = (prompts.AnswerWithRAGContextNumberPrompt.system_prompt_with_schema
                         if use_schema_prompt
                         else prompts.AnswerWithRAGContextNumberPrompt.system_prompt)
        response_format = prompts.AnswerWithRAGContextNumberPrompt.AnswerSchema
        user_prompt = prompts.AnswerWithRAGContextNumberPrompt.user_prompt
    # ... other answer types (boolean, names, comparative)
    else:
        raise ValueError(f"Unsupported schema: {schema}")
    return system_prompt, response_format, user_prompt
```
RAG-specific capabilities:
- Multiple answer types: name, number, boolean, names, and comparative schemas
- Prompt selection: providers without native structured output (IBM, Gemini) get prompts with the schema embedded
- Structured output: answers carry the step-by-step reasoning and the referenced pages
- Fallback: incomplete responses are padded into the full answer structure
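An illustrative call; the context string and question below are invented for the example:

```python
processor = APIProcessor(provider="openai")
answer = processor.get_answer_from_rag_context(
    question="What was the company's total revenue in 2023?",
    rag_context="Page 12: Total revenue for 2023 was $4.2B, up 8% year over year. ...",
    schema="number",
    model="gpt-4o-2024-08-06",
)
print(answer["final_answer"])
print(answer["relevant_pages"])
```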
7. Asynchronous Batch Processing
The system also offers a high-throughput asynchronous batch path:
```python
class AsyncOpenaiProcessor:
    async def process_structured_ouputs_requests(
        self,
        model="gpt-4o-mini-2024-07-18",
        temperature=0.5,
        seed=None,
        system_content="You are a helpful assistant.",
        queries=None,
        response_format=None,
        requests_filepath='./temp_async_llm_requests.jsonl',
        save_filepath='./temp_async_llm_results.jsonl',
        max_requests_per_minute=3_500,
        max_tokens_per_minute=3_500_000,
        progress_callback=None
    ):
        # Build the batch requests
        jsonl_requests = []
        for idx, query in enumerate(queries):
            request = {
                "model": model,
                "temperature": temperature,
                "seed": seed,
                "messages": [
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": query},
                ],
                'response_format': type_to_response_format_param(response_format),
                'metadata': {'original_index': idx}
            }
            jsonl_requests.append(request)

        # Write the requests to a JSONL file
        with open(requests_filepath, "w") as f:
            for request in jsonl_requests:
                json_string = json.dumps(request)
                f.write(json_string + "\n")

        # Async processing with progress monitoring
        async def monitor_progress():
            last_count = 0
            while True:
                try:
                    with open(save_filepath, 'r') as f:
                        current_count = sum(1 for _ in f)
                    if current_count > last_count:
                        if progress_callback:
                            for _ in range(current_count - last_count):
                                progress_callback()
                        last_count = current_count
                    if current_count >= len(jsonl_requests):
                        break
                except FileNotFoundError:
                    pass
                await asyncio.sleep(0.1)

        # Run request processing and progress monitoring in parallel
        await asyncio.gather(
            process_api_requests_from_file(
                requests_filepath=requests_filepath,
                save_filepath=save_filepath,
                request_url="https://api.openai.com/v1/chat/completions",
                api_key=os.getenv("OPENAI_API_KEY"),
                max_requests_per_minute=max_requests_per_minute,
                max_tokens_per_minute=max_tokens_per_minute,
                max_attempts=5
            ),
            monitor_progress()
        )

        # Parse the results
        with open(save_filepath, "r") as f:
            results = []
            for line_number, line in enumerate(f, start=1):
                try:
                    result = json.loads(line.strip())
                    answer_content = result[1]['choices'][0]['message']['content']
                    answer_parsed = json.loads(answer_content)
                    answer = response_format(**answer_parsed).model_dump()
                    results.append({
                        'index': result[2]['original_index'],
                        'question': result[0]['messages'],
                        'answer': answer
                    })
                except Exception as e:
                    print(f"[ERROR] Line {line_number}: Failed to parse. Error: {e}")

        # Restore the original input order
        validated_data_list = [
            {'question': r['question'], 'answer': r['answer']}
            for r in sorted(results, key=lambda x: x['index'])
        ]
        return validated_data_list
```
Advantages of the asynchronous path:
- High throughput: rate limits of thousands of requests per minute
- Progress monitoring: a progress callback fires as results stream into the JSONL file
- Order preservation: results are re-sorted to match the input order via the original_index metadata
- Error handling: per-line parse failures are logged without aborting the whole batch
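A driver sketch for running the batch from synchronous code; AnswerSchema is the illustrative Pydantic model from the earlier example:

```python
questions = [
    "What was revenue in 2023?",
    "Who is the CEO?",
    "Is the company profitable?",
]

results = asyncio.run(
    AsyncOpenaiProcessor().process_structured_ouputs_requests(
        queries=questions,
        response_format=AnswerSchema,
        progress_callback=lambda: print(".", end="", flush=True),
    )
)
print(results[0]["answer"])
```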
Smart Error Handling and Repair
JSON repair mechanism
The system implements a layered JSON repair mechanism:
```python
def _reparse_response(self, response, system_content, response_format):
    """Use the LLM itself to re-emit an invalid JSON response"""
    user_prompt = prompts.AnswerSchemaFixPrompt.user_prompt.format(
        system_prompt=system_content,
        response=response
    )
    reparsed_response = self.send_message(
        system_content=prompts.AnswerSchemaFixPrompt.system_prompt,
        human_content=user_prompt,
        is_structured=False
    )
    try:
        repaired_json = repair_json(reparsed_response)
        reparsed_dict = json.loads(repaired_json)
        validated_data = response_format.model_validate(reparsed_dict)
        print("Reparsing successful!")
        return validated_data.model_dump()
    except Exception as reparse_err:
        print(f"Reparse failed with error: {reparse_err}")
        return response
```
Repair strategy:
- Automatic repair: the json_repair library fixes common JSON formatting errors first
- LLM reparse: when automatic repair fails, the model is asked to re-emit the response in the expected schema
- Layered fallbacks: validation, repair, and reparse form successive safety nets; the raw response is returned only as a last resort
- Logging: each repair step prints its outcome
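What the first repair layer does on a typical malformed model response (the printed result is what json_repair typically produces for this input):

```python
import json
from json_repair import repair_json

broken = '{"final_answer": "42", "relevant_pages": [12, 13,'  # trailing comma, unclosed brackets
fixed = repair_json(broken)
print(json.loads(fixed))  # e.g. {'final_answer': '42', 'relevant_pages': [12, 13]}
```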
Practical Application Scenarios
1. Multi-Cloud Enterprise Deployment

```python
# Configure two providers, one primary and one backup
primary_processor = APIProcessor(provider="openai")
backup_processor = APIProcessor(provider="dashscope")

def robust_api_call(question, context, schema):
    try:
        return primary_processor.get_answer_from_rag_context(
            question=question, rag_context=context, schema=schema,
            model="gpt-4o-2024-08-06"
        )
    except Exception as e:
        print(f"Primary API failed: {e}, switching to backup...")
        return backup_processor.get_answer_from_rag_context(
            question=question, rag_context=context, schema=schema,
            model="qwen-turbo-latest"
        )
```
2. Cost Optimization

```python
# Pick a model based on question complexity
def cost_optimized_processing(question, context, schema):
    # Simple questions go to the cheaper model
    if len(question) < 50 and schema in ["boolean", "name"]:
        processor = APIProcessor(provider="dashscope")
        model = "qwen-turbo-latest"
    else:
        # Complex questions go to the stronger model
        processor = APIProcessor(provider="openai")
        model = "gpt-4o-2024-08-06"
    return processor.get_answer_from_rag_context(
        question=question,
        rag_context=context,
        schema=schema,
        model=model
    )
```
3. Batch Processing

```python
# Large-scale batch processing
async def batch_process_questions(questions_list):
    processor = AsyncOpenaiProcessor()
    queries = [q["question"] for q in questions_list]
    results = await processor.process_structured_ouputs_requests(
        model="gpt-4o-mini-2024-07-18",
        system_content="You are a helpful RAG assistant.",
        queries=queries,
        response_format=AnswerSchema,
        max_requests_per_minute=1000,
        progress_callback=lambda: print(".", end="")
    )
    return results
```
Performance Optimization Strategies
1. Connection Pool Management

```python
# Tune HTTP connection reuse and retries
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def setup_session():
    session = requests.Session()
    retry_strategy = Retry(
        total=3,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy, pool_connections=20, pool_maxsize=20)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session
```
2. Caching

```python
from functools import lru_cache
import hashlib

class CachedAPIProcessor(APIProcessor):
    # Note: lru_cache builds its key from all arguments, so every keyword value
    # passed through here must be hashable.
    @lru_cache(maxsize=1000)
    def cached_send_message(self, content_hash, **kwargs):
        return super().send_message(**kwargs)

    def send_message(self, **kwargs):
        # Hash the prompt content so identical prompts hit the cache
        content = f"{kwargs.get('system_content', '')}{kwargs.get('human_content', '')}"
        content_hash = hashlib.md5(content.encode()).hexdigest()
        return self.cached_send_message(content_hash, **kwargs)
```
3. Load Balancing

```python
import random
from typing import List

class LoadBalancedAPIProcessor:
    def __init__(self, providers: List[str]):
        self.processors = [APIProcessor(provider=p) for p in providers]
        self.weights = [1.0] * len(self.processors)  # adjust based on observed performance

    def send_message(self, **kwargs):
        # Pick a processor at random, weighted by current weights
        processor = random.choices(self.processors, weights=self.weights)[0]
        try:
            return processor.send_message(**kwargs)
        except Exception as e:
            # Penalize the failing processor
            idx = self.processors.index(processor)
            self.weights[idx] *= 0.8
            # Fall back to the remaining processors
            for other_processor in self.processors:
                if other_processor != processor:
                    try:
                        return other_processor.send_message(**kwargs)
                    except Exception:
                        continue
            raise e
```
Monitoring and Debugging
1. Detailed Logging

```python
import logging
from datetime import datetime

class LoggedAPIProcessor(APIProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger(f"APIProcessor-{self.provider}")

    def send_message(self, **kwargs):
        start_time = datetime.now()
        try:
            result = super().send_message(**kwargs)
            end_time = datetime.now()
            self.logger.info(
                f"API call successful - Provider: {self.provider}, "
                f"Duration: {(end_time - start_time).total_seconds():.2f}s, "
                f"Tokens: {getattr(self.processor, 'response_data', None)}"
            )
            return result
        except Exception as e:
            end_time = datetime.now()
            self.logger.error(
                f"API call failed - Provider: {self.provider}, "
                f"Duration: {(end_time - start_time).total_seconds():.2f}s, "
                f"Error: {str(e)}"
            )
            raise
```
2. Performance Monitoring

```python
import time
from collections import defaultdict

class PerformanceMonitor:
    def __init__(self):
        self.stats = defaultdict(list)

    def record_api_call(self, provider, duration, tokens_used, success):
        self.stats[provider].append({
            'duration': duration,
            'tokens': tokens_used,
            'success': success,
            'timestamp': time.time()
        })

    def get_stats(self, provider=None):
        if provider:
            calls = self.stats[provider]
        else:
            calls = []
            for provider_calls in self.stats.values():
                calls.extend(provider_calls)
        if not calls:
            return {}
        successful_calls = [c for c in calls if c['success']]
        return {
            'total_calls': len(calls),
            'success_rate': len(successful_calls) / len(calls),
            'avg_duration': (sum(c['duration'] for c in successful_calls) / len(successful_calls)
                             if successful_calls else 0),
            'total_tokens': sum(
                c['tokens'].get('input_tokens', 0) + c['tokens'].get('output_tokens', 0)
                for c in successful_calls if c['tokens']
            )
        }
```
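How the monitor is fed and queried (the numbers below are illustrative):

```python
monitor = PerformanceMonitor()
monitor.record_api_call(
    provider="openai",
    duration=1.8,
    tokens_used={"input_tokens": 1200, "output_tokens": 350},
    success=True,
)
monitor.record_api_call(provider="openai", duration=4.1, tokens_used=None, success=False)

print(monitor.get_stats("openai"))
# {'total_calls': 2, 'success_rate': 0.5, 'avg_duration': 1.8, 'total_tokens': 1550}
```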
Best Practice Recommendations
1. Provider Selection Strategy

```python
# Choose the best-fit provider for the scenario
def choose_optimal_provider(task_type, budget_level, latency_requirement):
    if budget_level == "low":
        return "dashscope"   # cost advantage
    elif latency_requirement == "ultra_low":
        return "openai"      # fast responses
    elif task_type == "multilingual":
        return "gemini"      # strong multilingual support
    elif task_type == "enterprise":
        return "ibm"         # enterprise features
    else:
        return "openai"      # default choice
```
2. Error Handling Best Practices

```python
import time
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10)
)
def robust_api_call(processor, **kwargs):
    try:
        return processor.send_message(**kwargs)
    except Exception as e:
        if "rate_limit" in str(e).lower():
            time.sleep(60)  # wait for the rate limit to reset
        raise
```
3. Configuration Management

```python
import yaml

class APIConfig:
    def __init__(self, config_file="api_config.yaml"):
        with open(config_file, 'r') as f:
            self.config = yaml.safe_load(f)

    def get_provider_config(self, provider):
        return self.config.get('providers', {}).get(provider, {})

    def get_model_for_task(self, task_type):
        return self.config.get('task_models', {}).get(task_type, "default")

# Example api_config.yaml
"""
providers:
  openai:
    default_model: "gpt-4o-2024-08-06"
    max_retries: 3
    timeout: 30
  dashscope:
    default_model: "qwen-turbo-latest"
    max_retries: 2
    timeout: 20

task_models:
  simple_qa: "gpt-4o-mini-2024-07-18"
  complex_analysis: "gpt-4o-2024-08-06"
  multilingual: "gemini-2.0-flash-001"
"""
```
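Resolving configuration at runtime, assuming an api_config.yaml with the structure shown above:

```python
config = APIConfig("api_config.yaml")
print(config.get_provider_config("openai"))       # {'default_model': 'gpt-4o-2024-08-06', ...}
print(config.get_model_for_task("simple_qa"))     # 'gpt-4o-mini-2024-07-18'
print(config.get_model_for_task("unknown_task"))  # falls back to 'default'
```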
Summary
This multi-provider API processing system demonstrates best practices for enterprise-grade LLM integration:
- Unified abstraction: one clean interface across multiple providers
- Smart adaptation: provider-specific optimizations where they matter
- Error recovery: retries, JSON repair, and fallback answer structures
- Performance: asynchronous batching, connection pooling, and caching
- Observability: detailed logging and performance monitoring
- Extensibility: new providers and features slot in easily
- Enterprise features: balance checks, an embedding service, and more
For anyone building an enterprise RAG system, this API processing architecture is a complete reference implementation. With sensible design and tuning, it delivers highly available, high-performance, low-cost LLM integration without sacrificing answer quality.
References
- OpenAI API documentation
- Google Gemini API documentation
- Alibaba Cloud DashScope documentation
- IBM Watson API documentation
- Tenacity retry library
This article is based on a source-code analysis of the API processing module from the RAG-Challenge-2 winning project and walks through a complete, production-grade multi-provider LLM integration along with its optimization strategies. I hope it helps developers building similar systems.
Appendix: full source code

```python
import os
import json
from dotenv import load_dotenv
from typing import Union, List, Dict, Type, Optional, Literal
from openai import OpenAI
import asyncio
from src.api_request_parallel_processor import process_api_requests_from_file
from openai.lib._parsing import type_to_response_format_param
import tiktoken
import src.prompts as prompts
import requests
from json_repair import repair_json
from pydantic import BaseModel
import google.generativeai as genai
from copy import deepcopy
from tenacity import retry, stop_after_attempt, wait_fixed
import dashscope


# OpenAI base processor: wraps message sending, structured output, and usage accounting
class BaseOpenaiProcessor:
    def __init__(self):
        self.llm = self.set_up_llm()
        self.default_model = 'gpt-4o-2024-08-06'
        # self.default_model = 'gpt-4o-mini-2024-07-18'

    def set_up_llm(self):
        # Load the OpenAI API key and initialize the client
        load_dotenv()
        llm = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=None,
            max_retries=2
        )
        return llm

    def send_message(
        self,
        model=None,
        temperature=0.5,
        seed=None,  # For deterministic outputs
        system_content='You are a helpful assistant.',
        human_content='Hello!',
        is_structured=False,
        response_format=None
    ):
        # Send a message to OpenAI; supports structured and unstructured output
        if model is None:
            model = self.default_model
        params = {
            "model": model,
            "seed": seed,
            "messages": [
                {"role": "system", "content": system_content},
                {"role": "user", "content": human_content}
            ]
        }

        # Some models do not accept a temperature parameter
        if "o3-mini" not in model:
            params["temperature"] = temperature

        if not is_structured:
            completion = self.llm.chat.completions.create(**params)
            content = completion.choices[0].message.content

        elif is_structured:
            params["response_format"] = response_format
            completion = self.llm.beta.chat.completions.parse(**params)
            response = completion.choices[0].message.parsed
            content = response.dict()

        self.response_data = {
            "model": completion.model,
            "input_tokens": completion.usage.prompt_tokens,
            "output_tokens": completion.usage.completion_tokens
        }
        print(self.response_data)

        return content

    @staticmethod
    def count_tokens(string, encoding_name="o200k_base"):
        # Count the number of tokens in a string
        encoding = tiktoken.get_encoding(encoding_name)
        # Encode the string and count the tokens
        tokens = encoding.encode(string)
        token_count = len(tokens)
        return token_count


# IBM API base processor: balance checks, model listing, embeddings, and message sending
class BaseIBMAPIProcessor:
    def __init__(self):
        load_dotenv()
        self.api_token = os.getenv("IBM_API_KEY")
        self.base_url = "https://rag.timetoact.at/ibm"
        self.default_model = 'meta-llama/llama-3-3-70b-instruct'

    def check_balance(self):
        """Query the current API balance"""
        balance_url = f"{self.base_url}/balance"
        headers = {"Authorization": f"Bearer {self.api_token}"}
        try:
            response = requests.get(balance_url, headers=headers)
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as err:
            print(f"Error checking balance: {err}")
            return None

    def get_available_models(self):
        """List the available foundation models"""
        models_url = f"{self.base_url}/foundation_model_specs"
        try:
            response = requests.get(models_url)
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as err:
            print(f"Error getting available models: {err}")
            return None

    def get_embeddings(self, texts, model_id="ibm/granite-embedding-278m-multilingual"):
        """Get vector embeddings for the given texts"""
        embeddings_url = f"{self.base_url}/embeddings"
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        payload = {
            "inputs": texts,
            "model_id": model_id
        }
        try:
            response = requests.post(embeddings_url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as err:
            print(f"Error getting embeddings: {err}")
            return None

    def send_message(
        self,
        # model='meta-llama/llama-3-1-8b-instruct',
        model=None,
        temperature=0.5,
        seed=None,  # For deterministic outputs
        system_content='You are a helpful assistant.',
        human_content='Hello!',
        is_structured=False,
        response_format=None,
        max_new_tokens=5000,
        min_new_tokens=1,
        **kwargs
    ):
        # Send a message to the IBM API; supports structured and unstructured output
        if model is None:
            model = self.default_model
        text_generation_url = f"{self.base_url}/text_generation"
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        # Prepare the input messages
        input_messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": human_content}
        ]
        # Prepare parameters with defaults and any additional parameters
        parameters = {
            "temperature": temperature,
            "random_seed": seed,
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min_new_tokens,
            **kwargs
        }
        payload = {
            "input": input_messages,
            "model_id": model,
            "parameters": parameters
        }
        try:
            response = requests.post(text_generation_url, headers=headers, json=payload)
            response.raise_for_status()
            completion = response.json()
            content = completion.get("results")[0].get("generated_text")
            self.response_data = {
                "model": completion.get("model_id"),
                "input_tokens": completion.get("results")[0].get("input_token_count"),
                "output_tokens": completion.get("results")[0].get("generated_token_count")
            }
            print(self.response_data)

            if is_structured and response_format is not None:
                try:
                    repaired_json = repair_json(content)
                    parsed_dict = json.loads(repaired_json)
                    validated_data = response_format.model_validate(parsed_dict)
                    content = validated_data.model_dump()
                    return content
                except Exception as err:
                    print("Error processing structured response, attempting to reparse the response...")
                    reparsed = self._reparse_response(content, system_content)
                    try:
                        repaired_json = repair_json(reparsed)
                        reparsed_dict = json.loads(repaired_json)
                        try:
                            validated_data = response_format.model_validate(reparsed_dict)
                            print("Reparsing successful!")
                            content = validated_data.model_dump()
                            return content
                        except Exception:
                            return reparsed_dict
                    except Exception as reparse_err:
                        print(f"Reparse failed with error: {reparse_err}")
                        print(f"Reparsed response: {reparsed}")
                        return content
            return content
        except requests.HTTPError as err:
            print(f"Error generating text: {err}")
            return None

    def _reparse_response(self, response, system_content):
        user_prompt = prompts.AnswerSchemaFixPrompt.user_prompt.format(
            system_prompt=system_content,
            response=response
        )
        reparsed_response = self.send_message(
            system_content=prompts.AnswerSchemaFixPrompt.system_prompt,
            human_content=user_prompt,
            is_structured=False
        )
        return reparsed_response


class BaseGeminiProcessor:
    def __init__(self):
        self.llm = self._set_up_llm()
        self.default_model = 'gemini-2.0-flash-001'
        # self.default_model = "gemini-2.0-flash-thinking-exp-01-21"

    def _set_up_llm(self):
        load_dotenv()
        api_key = os.getenv("GEMINI_API_KEY")
        genai.configure(api_key=api_key)
        return genai

    def list_available_models(self) -> None:
        """Prints available Gemini models that support text generation."""
        print("Available models for text generation:")
        for model in self.llm.list_models():
            if "generateContent" in model.supported_generation_methods:
                print(f"- {model.name}")
                print(f"  Input token limit: {model.input_token_limit}")
                print(f"  Output token limit: {model.output_token_limit}")
                print()

    def _log_retry_attempt(retry_state):
        """Print information about the retry attempt"""
        exception = retry_state.outcome.exception()
        print(f"\nAPI Error encountered: {str(exception)}")
        print("Waiting 20 seconds before retry...\n")

    @retry(
        wait=wait_fixed(20),
        stop=stop_after_attempt(3),
        before_sleep=_log_retry_attempt,
    )
    def _generate_with_retry(self, model, human_content, generation_config):
        """Wrapper for generate_content with retry logic"""
        try:
            return model.generate_content(
                human_content,
                generation_config=generation_config
            )
        except Exception as e:
            if getattr(e, '_attempt_number', 0) == 3:
                print(f"\nRetry failed. Error: {str(e)}\n")
            raise

    def _parse_structured_response(self, response_text, response_format):
        try:
            repaired_json = repair_json(response_text)
            parsed_dict = json.loads(repaired_json)
            validated_data = response_format.model_validate(parsed_dict)
            return validated_data.model_dump()
        except Exception as err:
            print(f"Error parsing structured response: {err}")
            print("Attempting to reparse the response...")
            reparsed = self._reparse_response(response_text, response_format)
            return reparsed

    def _reparse_response(self, response, response_format):
        """Reparse invalid JSON responses using the model itself."""
        user_prompt = prompts.AnswerSchemaFixPrompt.user_prompt.format(
            system_prompt=prompts.AnswerSchemaFixPrompt.system_prompt,
            response=response
        )
        try:
            reparsed_response = self.send_message(
                model="gemini-2.0-flash-001",
                system_content=prompts.AnswerSchemaFixPrompt.system_prompt,
                human_content=user_prompt,
                is_structured=False
            )
            try:
                repaired_json = repair_json(reparsed_response)
                reparsed_dict = json.loads(repaired_json)
                try:
                    validated_data = response_format.model_validate(reparsed_dict)
                    print("Reparsing successful!")
                    return validated_data.model_dump()
                except Exception:
                    return reparsed_dict
            except Exception as reparse_err:
                print(f"Reparse failed with error: {reparse_err}")
                print(f"Reparsed response: {reparsed_response}")
                return response
        except Exception as e:
            print(f"Reparse attempt failed: {e}")
            return response

    def send_message(
        self,
        model=None,
        temperature: float = 0.5,
        seed=12345,  # For back compatibility
        system_content: str = "You are a helpful assistant.",
        human_content: str = "Hello!",
        is_structured: bool = False,
        response_format: Optional[Type[BaseModel]] = None,
    ) -> Union[str, Dict, None]:
        if model is None:
            model = self.default_model
        generation_config = {"temperature": temperature}
        prompt = f"{system_content}\n\n---\n\n{human_content}"
        model_instance = self.llm.GenerativeModel(
            model_name=model,
            generation_config=generation_config
        )
        try:
            response = self._generate_with_retry(model_instance, prompt, generation_config)
            self.response_data = {
                "model": response.model_version,
                "input_tokens": response.usage_metadata.prompt_token_count,
                "output_tokens": response.usage_metadata.candidates_token_count
            }
            print(self.response_data)
            if is_structured and response_format is not None:
                return self._parse_structured_response(response.text, response_format)
            return response.text
        except Exception as e:
            raise Exception(f"API request failed after retries: {str(e)}")


class APIProcessor:
    def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] = "dashscope"):
        self.provider = provider.lower()
        if self.provider == "openai":
            self.processor = BaseOpenaiProcessor()
        elif self.provider == "ibm":
            self.processor = BaseIBMAPIProcessor()
        elif self.provider == "gemini":
            self.processor = BaseGeminiProcessor()
        elif self.provider == "dashscope":
            self.processor = BaseDashscopeProcessor()

    def send_message(
        self,
        model=None,
        temperature=0.5,
        seed=None,
        system_content="You are a helpful assistant.",
        human_content="Hello!",
        is_structured=False,
        response_format=None,
        **kwargs
    ):
        """Routes the send_message call to the appropriate processor.
        The underlying processor's send_message method is responsible for handling the parameters.
        """
        if model is None:
            model = self.processor.default_model
        return self.processor.send_message(
            model=model,
            temperature=temperature,
            seed=seed,
            system_content=system_content,
            human_content=human_content,
            is_structured=is_structured,
            response_format=response_format,
            **kwargs
        )

    def get_answer_from_rag_context(self, question, rag_context, schema, model):
        system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
        answer_dict = self.processor.send_message(
            model=model,
            system_content=system_prompt,
            human_content=user_prompt.format(context=rag_context, question=question),
            is_structured=True,
            response_format=response_format
        )
        self.response_data = self.processor.response_data
        # If answer_dict only contains final_answer, pad the remaining fields
        if 'step_by_step_analysis' not in answer_dict:
            answer_dict = {
                "step_by_step_analysis": "",
                "reasoning_summary": "",
                "relevant_pages": [],
                "final_answer": answer_dict.get("final_answer", "N/A")
            }
        return answer_dict

    def _build_rag_context_prompts(self, schema):
        """Return prompts tuple for the given schema."""
        use_schema_prompt = True if self.provider == "ibm" or self.provider == "gemini" else False
        if schema == "name":
            system_prompt = (prompts.AnswerWithRAGContextNamePrompt.system_prompt_with_schema
                             if use_schema_prompt
                             else prompts.AnswerWithRAGContextNamePrompt.system_prompt)
            response_format = prompts.AnswerWithRAGContextNamePrompt.AnswerSchema
            user_prompt = prompts.AnswerWithRAGContextNamePrompt.user_prompt
        elif schema == "number":
            system_prompt = (prompts.AnswerWithRAGContextNumberPrompt.system_prompt_with_schema
                             if use_schema_prompt
                             else prompts.AnswerWithRAGContextNumberPrompt.system_prompt)
            response_format = prompts.AnswerWithRAGContextNumberPrompt.AnswerSchema
            user_prompt = prompts.AnswerWithRAGContextNumberPrompt.user_prompt
        elif schema == "boolean":
            system_prompt = (prompts.AnswerWithRAGContextBooleanPrompt.system_prompt_with_schema
                             if use_schema_prompt
                             else prompts.AnswerWithRAGContextBooleanPrompt.system_prompt)
            response_format = prompts.AnswerWithRAGContextBooleanPrompt.AnswerSchema
            user_prompt = prompts.AnswerWithRAGContextBooleanPrompt.user_prompt
        elif schema == "names":
            system_prompt = (prompts.AnswerWithRAGContextNamesPrompt.system_prompt_with_schema
                             if use_schema_prompt
                             else prompts.AnswerWithRAGContextNamesPrompt.system_prompt)
            response_format = prompts.AnswerWithRAGContextNamesPrompt.AnswerSchema
            user_prompt = prompts.AnswerWithRAGContextNamesPrompt.user_prompt
        elif schema == "comparative":
            system_prompt = (prompts.ComparativeAnswerPrompt.system_prompt_with_schema
                             if use_schema_prompt
                             else prompts.ComparativeAnswerPrompt.system_prompt)
            response_format = prompts.ComparativeAnswerPrompt.AnswerSchema
            user_prompt = prompts.ComparativeAnswerPrompt.user_prompt
        else:
            raise ValueError(f"Unsupported schema: {schema}")
        return system_prompt, response_format, user_prompt

    def get_rephrased_questions(self, original_question: str, companies: List[str]) -> Dict[str, str]:
        """Use LLM to break down a comparative question into individual questions."""
        answer_dict = self.processor.send_message(
            system_content=prompts.RephrasedQuestionsPrompt.system_prompt,
            human_content=prompts.RephrasedQuestionsPrompt.user_prompt.format(
                question=original_question,
                companies=", ".join([f'"{company}"' for company in companies])
            ),
            is_structured=True,
            response_format=prompts.RephrasedQuestionsPrompt.RephrasedQuestions
        )
        # Convert the answer_dict to the desired format
        questions_dict = {item["company_name"]: item["question"] for item in answer_dict["questions"]}
        return questions_dict


class AsyncOpenaiProcessor:
    def _get_unique_filepath(self, base_filepath):
        """Helper method to get unique filepath"""
        if not os.path.exists(base_filepath):
            return base_filepath
        base, ext = os.path.splitext(base_filepath)
        counter = 1
        while os.path.exists(f"{base}_{counter}{ext}"):
            counter += 1
        return f"{base}_{counter}{ext}"

    async def process_structured_ouputs_requests(
        self,
        model="gpt-4o-mini-2024-07-18",
        temperature=0.5,
        seed=None,
        system_content="You are a helpful assistant.",
        queries=None,
        response_format=None,
        requests_filepath='./temp_async_llm_requests.jsonl',
        save_filepath='./temp_async_llm_results.jsonl',
        preserve_requests=False,
        preserve_results=True,
        request_url="https://api.openai.com/v1/chat/completions",
        max_requests_per_minute=3_500,
        max_tokens_per_minute=3_500_000,
        token_encoding_name="o200k_base",
        max_attempts=5,
        logging_level=20,
        progress_callback=None
    ):
        # Create requests for jsonl
        jsonl_requests = []
        for idx, query in enumerate(queries):
            request = {
                "model": model,
                "temperature": temperature,
                "seed": seed,
                "messages": [
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": query},
                ],
                'response_format': type_to_response_format_param(response_format),
                'metadata': {'original_index': idx}
            }
            jsonl_requests.append(request)

        # Get unique filepaths if files already exist
        requests_filepath = self._get_unique_filepath(requests_filepath)
        save_filepath = self._get_unique_filepath(save_filepath)

        # Write requests to JSONL file
        with open(requests_filepath, "w") as f:
            for request in jsonl_requests:
                json_string = json.dumps(request)
                f.write(json_string + "\n")

        # Process API requests
        total_requests = len(jsonl_requests)

        async def monitor_progress():
            last_count = 0
            while True:
                try:
                    with open(save_filepath, 'r') as f:
                        current_count = sum(1 for _ in f)
                    if current_count > last_count:
                        if progress_callback:
                            for _ in range(current_count - last_count):
                                progress_callback()
                        last_count = current_count
                    if current_count >= total_requests:
                        break
                except FileNotFoundError:
                    pass
                await asyncio.sleep(0.1)

        async def process_with_progress():
            await asyncio.gather(
                process_api_requests_from_file(
                    requests_filepath=requests_filepath,
                    save_filepath=save_filepath,
                    request_url=request_url,
                    api_key=os.getenv("OPENAI_API_KEY"),
                    max_requests_per_minute=max_requests_per_minute,
                    max_tokens_per_minute=max_tokens_per_minute,
                    token_encoding_name=token_encoding_name,
                    max_attempts=max_attempts,
                    logging_level=logging_level
                ),
                monitor_progress()
            )

        await process_with_progress()

        with open(save_filepath, "r") as f:
            validated_data_list = []
            results = []
            for line_number, line in enumerate(f, start=1):
                raw_line = line.strip()
                try:
                    result = json.loads(raw_line)
                except json.JSONDecodeError as e:
                    print(f"[ERROR] Line {line_number}: Failed to load JSON from line: {raw_line}")
                    continue
                # Check finish_reason in the API response
                finish_reason = result[1]['choices'][0].get('finish_reason', '')
                if finish_reason != "stop":
                    print(f"[WARNING] Line {line_number}: finish_reason is '{finish_reason}' (expected 'stop').")
                # Safely parse answer; if it fails, leave answer empty and report the error.
                try:
                    answer_content = result[1]['choices'][0]['message']['content']
                    answer_parsed = json.loads(answer_content)
                    answer = response_format(**answer_parsed).model_dump()
                except Exception as e:
                    print(f"[ERROR] Line {line_number}: Failed to parse answer JSON. Error: {e}.")
                    answer = ""
                results.append({
                    'index': result[2],
                    'question': result[0]['messages'],
                    'answer': answer
                })

        # Sort by original index and build final list
        validated_data_list = [
            {'question': r['question'], 'answer': r['answer']}
            for r in sorted(results, key=lambda x: x['index']['original_index'])
        ]

        if not preserve_requests:
            os.remove(requests_filepath)

        if not preserve_results:
            os.remove(save_filepath)
        else:  # Fix requests order
            with open(save_filepath, "r") as f:
                results = [json.loads(line) for line in f]
            sorted_results = sorted(results, key=lambda x: x[2]['original_index'])
            with open(save_filepath, "w") as f:
                for result in sorted_results:
                    json_string = json.dumps(result)
                    f.write(json_string + "\n")

        return validated_data_list


# DashScope base processor: chat with Qwen models
class BaseDashscopeProcessor:
    def __init__(self):
        # Read the API key from the environment
        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
        self.default_model = 'qwen-turbo-latest'

    def send_message(
        self,
        model="qwen-turbo-latest",
        temperature=0.1,
        seed=None,  # compatibility parameter, currently unused
        system_content='You are a helpful assistant.',
        human_content='Hello!',
        is_structured=False,
        response_format=None,
        **kwargs
    ):
        """Send a message to the DashScope Qwen model.
        system_content and human_content are combined into the messages list.
        Structured output is not yet supported.
        """
        if model is None:
            model = self.default_model
        # Assemble the messages
        messages = []
        if system_content:
            messages.append({"role": "system", "content": system_content})
        if human_content:
            messages.append({"role": "user", "content": human_content})
        # print('system_content=', system_content)
        # print('human_content=', human_content)
        # print('messages=', messages)
        # Call dashscope Generation.call
        response = dashscope.Generation.call(
            model=model,
            messages=messages,
            temperature=temperature,
            result_format='message'
        )
        print('model=', model)
        print('response=', response)
        # Normalize to the openai/gemini return shape; always return a dict
        if hasattr(response, 'output') and hasattr(response.output, 'choices'):
            content = response.output.choices[0].message.content
        else:
            content = str(response)
        # Provide response_data for interface consistency
        self.response_data = {"model": model, "input_tokens": None, "output_tokens": None}
        print('content=', content)
        # Always return a dict to avoid downstream AttributeError
        return {"final_answer": content}
```