
Anatomy of a RAG Question-Processing System: The Engineering Behind QuestionsProcessor.py for Enterprise Q&A

Full source code is included at the end of the article.

Introduction

In enterprise knowledge bases and intelligent Q&A systems, question processing is the bridge between user queries and knowledge retrieval. This article dissects the question-processing implementation from a winning entry of the RAG Challenge. The system supports single-company queries, multi-company comparisons, parallel processing, and error recovery, and it illustrates the complete engineering practice behind a modern RAG system.

System Architecture Overview

The question-processing system uses a modular, layered architecture:

  1. QuestionsProcessor: the core question processor that orchestrates the entire Q&A flow

  2. APIProcessor: a multi-provider API processor supporting OpenAI, IBM, Gemini, and DashScope

  3. Retrieval integration: seamless integration of vector retrieval and hybrid retrieval

  4. Parallel processing: multi-threaded concurrency and batch processing

  5. Error recovery: robust exception handling and checkpointed progress saving

Core Components in Detail

1. The Core Processor Class (QuestionsProcessor)

QuestionsProcessor is the system's central controller; it coordinates retrieval, reasoning, and answer generation.

class QuestionsProcessor:
    def __init__(self,
        vector_db_dir: Union[str, Path] = './vector_dbs',
        documents_dir: Union[str, Path] = './documents',
        questions_file_path: Optional[Union[str, Path]] = None,
        new_challenge_pipeline: bool = False,
        subset_path: Optional[Union[str, Path]] = None,
        parent_document_retrieval: bool = False,  # enable parent-document retrieval
        llm_reranking: bool = False,              # enable LLM reranking
        llm_reranking_sample_size: int = 20,
        top_n_retrieval: int = 10,
        parallel_requests: int = 10,
        api_provider: str = "dashscope",            # or "openai"
        answering_model: str = "qwen-turbo-latest", # or "gpt-4o-2024-08-06"
        full_context: bool = False
    ):
        # Initialize configuration parameters
        self.questions = self._load_questions(questions_file_path)
        self.documents_dir = Path(documents_dir)
        self.vector_db_dir = Path(vector_db_dir)
        self.subset_path = Path(subset_path) if subset_path else None
        self.new_challenge_pipeline = new_challenge_pipeline

        # Retrieval strategy configuration
        self.return_parent_pages = parent_document_retrieval
        self.llm_reranking = llm_reranking
        self.llm_reranking_sample_size = llm_reranking_sample_size
        self.top_n_retrieval = top_n_retrieval
        self.full_context = full_context

        # API and concurrency configuration
        self.api_provider = api_provider
        self.answering_model = answering_model
        self.parallel_requests = parallel_requests
        self.openai_processor = APIProcessor(provider=api_provider)

        # Thread safety and state management
        self.answer_details = []
        self.detail_counter = 0
        self._lock = threading.Lock()

Design Highlights

  • Rich configuration options covering multiple retrieval strategies

  • Support for several API providers, improving availability

  • Thread-safe design for concurrent processing

  • Flexible switching between pipeline modes

2. Single-Company Q&A Flow

Answering questions about a single company is the system's core capability:

def get_answer_for_company(self, company_name: str, question: str, schema: str) -> dict:
    # Select a retriever based on configuration
    if self.llm_reranking:
        retriever = HybridRetriever(
            vector_db_dir=self.vector_db_dir,
            documents_dir=self.documents_dir
        )
    else:
        retriever = VectorRetriever(
            vector_db_dir=self.vector_db_dir,
            documents_dir=self.documents_dir
        )

    # Run retrieval
    if self.full_context:
        retrieval_results = retriever.retrieve_all(company_name)
    else:
        retrieval_results = retriever.retrieve_by_company_name(
            company_name=company_name,
            query=question,
            llm_reranking_sample_size=self.llm_reranking_sample_size,
            top_n=self.top_n_retrieval,
            return_parent_pages=self.return_parent_pages
        )

    if not retrieval_results:
        raise ValueError("No relevant context found")

    # Format retrieval results into the RAG context
    rag_context = self._format_retrieval_results(retrieval_results)

    # Call the LLM to generate an answer
    answer_dict = self.openai_processor.get_answer_from_rag_context(
        question=question,
        rag_context=rag_context,
        schema=schema,
        model=self.answering_model
    )

    # Post-processing: page-reference validation and reference extraction
    if self.new_challenge_pipeline:
        pages = answer_dict.get("relevant_pages", [])
        validated_pages = self._validate_page_references(pages, retrieval_results)
        answer_dict["relevant_pages"] = validated_pages
        answer_dict["references"] = self._extract_references(validated_pages, company_name)
    return answer_dict

Technical Highlights

  • Retriever selection: the retrieval strategy is chosen automatically from the configuration

  • Flexible context modes: both full-document context and targeted retrieval are supported

  • Page-reference validation: guards against LLM hallucinations and keeps citations accurate

  • Structured output: multiple answer types are supported (name, number, boolean, names); an illustrative example follows
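To make the structured output concrete, here is an illustrative answer dict for a "number" question. The field names (step_by_step_analysis, reasoning_summary, relevant_pages, final_answer) are the ones used throughout the code; the values are invented for demonstration:

# Illustrative only -- the values are invented, the keys match those used by the code.
example_answer_dict = {
    "step_by_step_analysis": "Page 42 of the annual report states the Q4 2023 net income ...",
    "reasoning_summary": "Figure taken directly from the consolidated income statement.",
    "relevant_pages": [42, 43],
    "final_answer": 1234500000   # schema "number" -> a numeric value
}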

3. Formatting Retrieval Results

Retrieved chunks are converted into a context format the LLM can consume:

def _format_retrieval_results(self, retrieval_results) -> str:
    """Format retrieval results into a RAG context string."""
    if not retrieval_results:
        return ""
    context_parts = []
    for result in retrieval_results:
        page_number = result['page']
        text = result['text']
        context_parts.append(f'Text retrieved from page {page_number}: \n"""\n{text}\n"""')
    return "\n\n---\n\n".join(context_parts)

Formatting Strategy

  • Clear page labels, making it easy for the LLM to understand and cite sources

  • A consistent separator between chunks, improving parsing reliability

  • Structured text layout that helps the LLM read the context (an example of the resulting string follows)
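With two chunks retrieved from pages 12 and 15 (chunk text invented and abbreviated here), the method produces a context string of the following shape:

Text retrieved from page 12: 
"""
Revenue for fiscal year 2023 was ...
"""

---

Text retrieved from page 15: 
"""
Research and development expenses increased by ...
"""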

4. Page-Reference Validation

A validation layer that stops the LLM from fabricating citations:

def _validate_page_references(self, claimed_pages: list, retrieval_results: list, min_pages: int = 2, max_pages: int = 8) -> list:
    """Validate that the pages cited in the LLM answer actually appear in the retrieval results.
    If fewer than min_pages remain, top up with the highest-ranked retrieved pages."""
    if claimed_pages is None:
        claimed_pages = []

    # Pages that were actually retrieved
    retrieved_pages = [result['page'] for result in retrieval_results]

    # Keep only claimed pages that really exist in the retrieval results
    validated_pages = [page for page in claimed_pages if page in retrieved_pages]

    # Log any hallucinated references that were removed
    if len(validated_pages) < len(claimed_pages):
        removed_pages = set(claimed_pages) - set(validated_pages)
        print(f"Warning: Removed {len(removed_pages)} hallucinated page references: {removed_pages}")

    # If fewer than the minimum remain, top up from the retrieval results
    if len(validated_pages) < min_pages and retrieval_results:
        existing_pages = set(validated_pages)
        for result in retrieval_results:
            page = result['page']
            if page not in existing_pages:
                validated_pages.append(page)
                existing_pages.add(page)
                if len(validated_pages) >= min_pages:
                    break

    # Cap the number of cited pages
    if len(validated_pages) > max_pages:
        print(f"Trimming references from {len(validated_pages)} to {max_pages} pages")
        validated_pages = validated_pages[:max_pages]
    return validated_pages

Why This Validation Matters

  • Hallucination detection: page numbers the LLM invented are identified and removed automatically (see the demonstration after this list)

  • Automatic top-up: when too few citations remain, top-ranked retrieved pages are added

  • Quantity control: the number of citations is capped so answer quality does not suffer

  • Transparent logging: each validation step is logged, which makes debugging easier
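A quick standalone demonstration of the behaviour, using a trimmed-down copy of the same logic and made-up page numbers (not code from the project):

# Made-up retrieval results; page 99 was never retrieved, so it is a hallucinated citation.
retrieval_results = [{'page': 12, 'text': '...'}, {'page': 15, 'text': '...'}, {'page': 20, 'text': '...'}]
claimed_pages = [12, 99]

retrieved = [r['page'] for r in retrieval_results]
validated = [p for p in claimed_pages if p in retrieved]   # -> [12], page 99 dropped
for r in retrieval_results:                                # top up to min_pages=2
    if r['page'] not in validated:
        validated.append(r['page'])
        break
print(validated)                                           # -> [12, 15]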

5. Comparative Q&A Across Multiple Companies

An advanced capability that supports complex comparisons across several companies:

def process_comparative_question(self, question: str, companies: List[str], schema: str) -> dict:
    """Handle a comparative question across multiple companies:
    1. Rewrite the comparative question into one question per company
    2. Process each company in parallel
    3. Aggregate the results and generate the final comparative answer"""
    # Step 1: rewrite the question
    rephrased_questions = self.openai_processor.get_rephrased_questions(
        original_question=question,
        companies=companies
    )

    individual_answers = {}
    aggregated_references = []

    # Step 2: process each company's sub-question in parallel
    def process_company_question(company: str) -> tuple[str, dict]:
        """Helper that answers one company's sub-question."""
        sub_question = rephrased_questions.get(company)
        if not sub_question:
            raise ValueError(f"Could not generate sub-question for company: {company}")
        answer_dict = self.get_answer_for_company(company_name=company, question=sub_question, schema="number")
        return company, answer_dict

    # Run the per-company questions in a thread pool
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_company = {executor.submit(process_company_question, company): company for company in companies}
        for future in concurrent.futures.as_completed(future_to_company):
            try:
                company, answer_dict = future.result()
                individual_answers[company] = answer_dict
                # Aggregate reference information
                company_references = answer_dict.get("references", [])
                aggregated_references.extend(company_references)
            except Exception as e:
                company = future_to_company[future]
                print(f"Error processing company {company}: {str(e)}")
                raise

    # De-duplicate references
    unique_refs = {}
    for ref in aggregated_references:
        key = (ref.get("pdf_sha1"), ref.get("page_index"))
        unique_refs[key] = ref
    aggregated_references = list(unique_refs.values())

    # Step 3: generate the comparative answer from the individual answers
    comparative_answer = self.openai_processor.get_answer_from_rag_context(
        question=question,
        rag_context=individual_answers,
        schema="comparative",
        model=self.answering_model
    )
    comparative_answer["references"] = aggregated_references
    return comparative_answer

Comparative Q&A Highlights

  • Question decomposition: the comparative question is automatically rewritten into one question per company (illustrated below)

  • Parallel processing: the companies are handled concurrently in multiple threads, improving throughput

  • Result aggregation: the per-company answers and references are merged into a single response

  • De-duplication: duplicate references are removed automatically
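The code expects get_rephrased_questions to return a mapping from company name to a rewritten single-company question, roughly like the following (the question strings here are invented for illustration):

# Illustrative shape of the rewrite step -- the values are invented.
original_question = "Which company had higher R&D spending in 2023, 'Apple Inc.' or 'Microsoft Corporation'?"
rephrased_questions = {
    "Apple Inc.": "What was Apple Inc.'s R&D spending in 2023?",
    "Microsoft Corporation": "What was Microsoft Corporation's R&D spending in 2023?",
}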

6. Batch Processing and Concurrency Control

An efficient pipeline for processing questions in bulk:

def process_questions_list(self, questions_list: List[dict], output_path: str = None, submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = "") -> dict:
    # Process a list of questions, with optional parallelism and checkpointed saves
    total_questions = len(questions_list)
    questions_with_index = [{**q, "_question_index": i} for i, q in enumerate(questions_list)]
    self.answer_details = [None] * total_questions  # pre-allocate the answer-details list
    processed_questions = []
    parallel_threads = self.parallel_requests

    if parallel_threads <= 1:
        # Single-threaded sequential processing
        for question_data in tqdm(questions_with_index, desc="Processing questions"):
            processed_question = self._process_single_question(question_data)
            processed_questions.append(processed_question)
            if output_path:
                self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
    else:
        # Multi-threaded parallel processing
        with tqdm(total=total_questions, desc="Processing questions") as pbar:
            for i in range(0, total_questions, parallel_threads):
                batch = questions_with_index[i : i + parallel_threads]
                with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_threads) as executor:
                    # executor.map keeps results in input order
                    batch_results = list(executor.map(self._process_single_question, batch))
                processed_questions.extend(batch_results)
                if output_path:
                    self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
                pbar.update(len(batch_results))

    statistics = self._calculate_statistics(processed_questions, print_stats=True)
    return {
        "questions": processed_questions,
        "answer_details": self.answer_details,
        "statistics": statistics
    }

Concurrency Advantages

  • Flexible concurrency control: both single-threaded and multi-threaded modes are supported

  • Batched execution: questions are processed batch by batch, balancing throughput against resource usage

  • Progress visualization: processing progress is shown in real time

  • Checkpointing: progress is saved along the way, so an interrupted run can be resumed (a resume sketch follows this list)
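The class itself only saves progress (via _save_progress, shown in the full code at the end); it does not ship a resume helper. A minimal sketch of one, assuming the "_debug" JSON file that _save_progress writes, could look like this:

# Hypothetical resume helper -- not part of the original project.
# Assumes the "<output>_debug.json" file written by _save_progress: {"questions": [...], ...}.
import json

def remaining_questions(all_questions: list, debug_path: str) -> list:
    try:
        with open(debug_path, 'r', encoding='utf-8') as f:
            done = json.load(f).get("questions", [])
    except FileNotFoundError:
        return all_questions  # nothing saved yet, start from scratch
    answered = {q.get("question_text") or q.get("question") for q in done if "error" not in q}
    return [q for q in all_questions if (q.get("text") or q.get("question")) not in answered]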

7. Error Handling and Recovery

A thorough exception-handling mechanism keeps the system stable:

def _handle_processing_error(self, question_text: str, schema: str, err: Exception, question_index: int) -> dict:
    """Handle an exception raised while processing a question.
    Records the error details and returns a dict describing the failure."""
    import traceback
    error_message = str(err)
    tb = traceback.format_exc()

    error_ref = f"#/answer_details/{question_index}"
    error_detail = {
        "error_traceback": tb,
        "self": error_ref
    }

    # Thread-safe error recording
    with self._lock:
        self.answer_details[question_index] = error_detail

    # Detailed error logging
    print(f"Error encountered processing question: {question_text}")
    print(f"Error type: {type(err).__name__}")
    print(f"Error message: {error_message}")
    print(f"Full traceback:\n{tb}\n")

    # Return a standardized error response
    if self.new_challenge_pipeline:
        return {
            "question_text": question_text,
            "kind": schema,
            "value": None,
            "references": [],
            "error": f"{type(err).__name__}: {error_message}",
            "answer_details": {"$ref": error_ref}
        }
    else:
        return {
            "question": question_text,
            "schema": schema,
            "answer": None,
            "error": f"{type(err).__name__}: {error_message}",
            "answer_details": {"$ref": error_ref},
        }

Error-Handling Highlights

  • Detailed error records: full stack traces and error context are kept

  • Thread safety: errors are recorded safely even in multi-threaded runs

  • Standardized responses: errors are returned in a uniform format

  • Debug friendly: rich diagnostic output

8. Statistics and Monitoring

Real-time processing statistics and performance monitoring:

def _calculate_statistics(self, processed_questions: List[dict], print_stats: bool = False) -> dict:
    """Summarize processing results: totals, errors, N/A answers, and successes."""
    total_questions = len(processed_questions)
    error_count = sum(1 for q in processed_questions if "error" in q)
    na_count = sum(1 for q in processed_questions if (q.get("value") if "value" in q else q.get("answer")) == "N/A")
    success_count = total_questions - error_count - na_count

    if print_stats:
        print(f"\nFinal Processing Statistics:")
        print(f"Total questions: {total_questions}")
        print(f"Errors: {error_count} ({(error_count/total_questions)*100:.1f}%)")
        print(f"N/A answers: {na_count} ({(na_count/total_questions)*100:.1f}%)")
        print(f"Successfully answered: {success_count} ({(success_count/total_questions)*100:.1f}%)\n")

    return {
        "total_questions": total_questions,
        "error_count": error_count,
        "na_count": na_count,
        "success_count": success_count
    }

API Processor Architecture

A Unified Interface over Multiple Providers

# Excerpt from src/api_requests.py
class APIProcessor:
    def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] = "dashscope"):
        self.provider = provider.lower()
        if self.provider == "openai":
            self.processor = BaseOpenaiProcessor()
        elif self.provider == "ibm":
            self.processor = BaseIBMAPIProcessor()
        elif self.provider == "gemini":
            self.processor = BaseGeminiProcessor()
        elif self.provider == "dashscope":
            self.processor = BaseDashscopeProcessor()

    def get_answer_from_rag_context(self, question, rag_context, schema, model):
        system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
        answer_dict = self.processor.send_message(
            model=model,
            system_content=system_prompt,
            human_content=user_prompt.format(context=rag_context, question=question),
            is_structured=True,
            response_format=response_format
        )
        # Fallback: make sure the returned dict always has the full answer structure
        if 'step_by_step_analysis' not in answer_dict:
            answer_dict = {
                "step_by_step_analysis": "",
                "reasoning_summary": "",
                "relevant_pages": [],
                "final_answer": answer_dict.get("final_answer", "N/A")
            }
        return answer_dict

API Processor Advantages

  • Unified interface: hides the API differences between providers

  • Provider adaptation: parameters are adjusted to the characteristics of each provider

  • Fault tolerance: fallback and retry logic

  • Extensibility: new API providers are easy to add (a sketch follows)
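Adding a provider amounts to one more branch in APIProcessor.__init__ plus a Base*Processor class exposing the same send_message interface. A sketch for a hypothetical provider (not part of the project):

# Hypothetical provider class -- only a sketch of the required interface.
class BaseAnotherProviderProcessor:
    def send_message(self, model, system_content, human_content,
                     is_structured=False, response_format=None):
        ...  # call the provider's SDK and return a dict shaped like the other processors' output

# and in APIProcessor.__init__:
# elif self.provider == "another_provider":
#     self.processor = BaseAnotherProviderProcessor()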

Practical Application Scenarios

1. Corporate Financial Analysis

# Single-company financial query
processor = QuestionsProcessor(
    vector_db_dir="./financial_dbs",
    documents_dir="./financial_docs",
    llm_reranking=True,
    api_provider="openai",
    answering_model="gpt-4o-2024-08-06"
)

answer = processor.get_answer_for_company(
    company_name="Apple Inc.",
    question="What was the net income in Q4 2023?",
    schema="number"
)

2. Multi-Company Comparative Analysis

# Multi-company comparative query
comparative_answer = processor.process_comparative_question(
    question="Which company had higher R&D spending in 2023, 'Apple Inc.' or 'Microsoft Corporation'?",
    companies=["Apple Inc.", "Microsoft Corporation"],
    schema="comparative"
)

3. Batch Question Processing

# Large-scale batch processing
questions_list = [
    {"question": "Who is the company's CEO?", "schema": "name"},
    {"question": "What was the total revenue in 2023?", "schema": "number"},
    {"question": "Did the company buy back shares?", "schema": "boolean"}
]

# Note: the degree of parallelism is set via parallel_requests in the
# QuestionsProcessor constructor, not as an argument to this method.
results = processor.process_questions_list(
    questions_list=questions_list,
    output_path="./results.json",
    submission_file=True
)

Performance Optimization Strategies

1. Retrieval Optimization

  • Retriever selection: the best retrieval strategy is chosen automatically for the query type

  • Caching: retrieval results for frequent queries are cached (a sketch follows this list)

  • Batched retrieval: similar queries are merged to cut down the number of retrieval calls
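Caching is an optimization idea rather than a shipped feature. A minimal sketch of a retrieval cache that wraps whichever retriever the processor builds might look like this (the cache key deliberately ignores the extra keyword options to keep the sketch simple):

# Hedged sketch of a retrieval cache -- not part of the original project.
class CachedRetriever:
    def __init__(self, retriever):
        self._retriever = retriever
        self._cache = {}   # (company_name, query) -> retrieval results

    def retrieve_by_company_name(self, company_name: str, query: str, **kwargs) -> list:
        key = (company_name, query)
        if key not in self._cache:
            self._cache[key] = self._retriever.retrieve_by_company_name(
                company_name=company_name, query=query, **kwargs)
        return self._cache[key]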

2. Concurrency Optimization

  • Dynamic thread pools: concurrency is adjusted to the current system load

  • Batched execution: balances the degree of parallelism against resource consumption

  • Load balancing: requests are spread across multiple API providers
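Load balancing across providers is likewise a suggestion rather than something the project implements. A round-robin sketch on top of APIProcessor, assuming every configured provider can serve the same requests:

# Hedged sketch -- not part of the original project.
import itertools
import threading

class RoundRobinAPIProcessor:
    def __init__(self, providers=("openai", "dashscope")):
        self._processors = [APIProcessor(provider=p) for p in providers]
        self._cycle = itertools.cycle(self._processors)
        self._lock = threading.Lock()

    def get_answer_from_rag_context(self, **kwargs):
        with self._lock:                 # cycle() is not thread-safe on its own
            processor = next(self._cycle)
        return processor.get_answer_from_rag_context(**kwargs)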

3. Memory Management

  • Streaming: large datasets are processed in a streaming fashion

  • Prompt release: resources are freed as soon as processing finishes

  • Memory monitoring: memory usage is tracked in real time

System Monitoring and Debugging

1. Real-Time Monitoring

# Monitor processing statistics
statistics = processor._calculate_statistics(processed_questions, print_stats=True)
print(f"Success rate: {(statistics['success_count']/statistics['total_questions'])*100:.1f}%")
print(f"Error rate: {(statistics['error_count']/statistics['total_questions'])*100:.1f}%")

2. Detailed Logging

# Enable detailed logging
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom log records
def log_processing_details(question, answer, processing_time):
    logger.info(f"Question: {question}")
    logger.info(f"Answer: {answer.get('final_answer', 'N/A')}")
    logger.info(f"Processing time: {processing_time:.2f}s")

3. Error Analysis

# Error statistics
def analyze_errors(processed_questions):
    errors = [q for q in processed_questions if "error" in q]
    error_types = {}
    for error in errors:
        error_type = error["error"].split(":")[0]
        error_types[error_type] = error_types.get(error_type, 0) + 1

    print("Error Analysis:")
    for error_type, count in error_types.items():
        print(f"  {error_type}: {count}")

Best-Practice Recommendations

1. Configuration Tuning

# Recommended production configuration
production_config = {
    "llm_reranking": True,                    # enable reranking for higher answer quality
    "parent_document_retrieval": True,        # enable parent-document retrieval
    "top_n_retrieval": 10,                    # a moderate number of retrieved chunks
    "parallel_requests": 5,                   # stay under API rate limits
    "api_provider": "openai",                 # a stable API provider
    "answering_model": "gpt-4o-2024-08-06"    # a high-quality model
}

2. Error Handling

# Robust processing with retries and exponential backoff
import time

def robust_question_processing(processor, question, max_retries=3):
    for attempt in range(max_retries):
        try:
            return processor.process_question(question["question"], question["schema"])
        except Exception as e:
            if attempt == max_retries - 1:
                return {"error": f"Failed after {max_retries} attempts: {str(e)}"}
            time.sleep(2 ** attempt)  # exponential backoff

3. Performance Monitoring

# Performance-monitoring decorator
import time
from functools import wraps

def monitor_performance(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
        return result
    return wrapper
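One way to apply it without editing the class is to wrap a bound method on an existing processor instance (illustrative usage, not from the project):

# Time every single-company answer on this processor instance.
processor.get_answer_for_company = monitor_performance(processor.get_answer_for_company)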

Summary

This question-processing system demonstrates the complete engineering practice behind an enterprise-grade RAG system:

  1. Modular architecture: a clean layered design that is easy to maintain and extend

  2. Multi-format support: multiple question types and answer formats

  3. Concurrent processing: an efficient multi-threaded pipeline

  4. Error recovery: robust exception handling and checkpointed progress

  5. Monitoring and debugging: rich statistics and debugging tools

  6. API abstraction: a unified interface over multiple API providers

  7. Citation validation: a page-reference check that guards against LLM hallucinations

For anyone building an enterprise-grade intelligent Q&A system, this implementation offers a complete reference architecture along with a set of best practices. With sensible design and tuning, it delivers efficient, stable large-scale question processing without sacrificing answer quality.

Further Reading

  • OpenAI API documentation

  • DashScope API documentation

  • Best practices for concurrent programming in Python

  • RAG system design patterns


This article is based on a source-code analysis of the question-processing module from the prize-winning RAG-Challenge-2 project and walks through the complete implementation and optimization strategies of an industrial-strength Q&A system. I hope it is useful to developers building similar systems.

Full Source Code

import json
from typing import Union, Dict, List, Optional
import re
from pathlib import Path
from src.retrieval import VectorRetriever, HybridRetriever
from src.api_requests import APIProcessor
from tqdm import tqdm
import pandas as pd
import threading
import concurrent.futures


class QuestionsProcessor:
    def __init__(self,
        vector_db_dir: Union[str, Path] = './vector_dbs',
        documents_dir: Union[str, Path] = './documents',
        questions_file_path: Optional[Union[str, Path]] = None,
        new_challenge_pipeline: bool = False,
        subset_path: Optional[Union[str, Path]] = None,
        parent_document_retrieval: bool = False,  # enable parent-document retrieval
        llm_reranking: bool = False,              # enable LLM reranking
        llm_reranking_sample_size: int = 20,
        top_n_retrieval: int = 10,
        parallel_requests: int = 10,
        api_provider: str = "dashscope",            # or "openai"
        answering_model: str = "qwen-turbo-latest", # or "gpt-4o-2024-08-06"
        full_context: bool = False
    ):
        # Initialize the processor: retrieval, model, and concurrency settings
        self.questions = self._load_questions(questions_file_path)
        self.documents_dir = Path(documents_dir)
        self.vector_db_dir = Path(vector_db_dir)
        self.subset_path = Path(subset_path) if subset_path else None
        self.new_challenge_pipeline = new_challenge_pipeline
        self.return_parent_pages = parent_document_retrieval
        self.llm_reranking = llm_reranking
        self.llm_reranking_sample_size = llm_reranking_sample_size
        self.top_n_retrieval = top_n_retrieval
        self.answering_model = answering_model
        self.parallel_requests = parallel_requests
        self.api_provider = api_provider
        self.openai_processor = APIProcessor(provider=api_provider)
        self.full_context = full_context

        self.answer_details = []
        self.detail_counter = 0
        self._lock = threading.Lock()

    def _load_questions(self, questions_file_path: Optional[Union[str, Path]]) -> List[Dict[str, str]]:
        # Load the questions file and return the list of questions
        if questions_file_path is None:
            return []
        with open(questions_file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _format_retrieval_results(self, retrieval_results) -> str:
        """Format retrieval results into a RAG context string."""
        if not retrieval_results:
            return ""
        context_parts = []
        for result in retrieval_results:
            page_number = result['page']
            text = result['text']
            context_parts.append(f'Text retrieved from page {page_number}: \n"""\n{text}\n"""')
        return "\n\n---\n\n".join(context_parts)

    def _extract_references(self, pages_list: list, company_name: str) -> list:
        # Build reference entries from the company name and the list of page numbers
        if self.subset_path is None:
            raise ValueError("subset_path is required for new challenge pipeline when processing references.")
        self.companies_df = pd.read_csv(self.subset_path)
        # Find the company's SHA1 from the subset CSV
        matching_rows = self.companies_df[self.companies_df['company_name'] == company_name]
        if matching_rows.empty:
            company_sha1 = ""
        else:
            company_sha1 = matching_rows.iloc[0]['sha1']
        refs = []
        for page in pages_list:
            refs.append({"pdf_sha1": company_sha1, "page_index": page})
        return refs

    def _validate_page_references(self, claimed_pages: list, retrieval_results: list, min_pages: int = 2, max_pages: int = 8) -> list:
        """Validate that the pages cited in the LLM answer actually appear in the retrieval results.
        If fewer than min_pages remain, top up with the highest-ranked retrieved pages."""
        if claimed_pages is None:
            claimed_pages = []
        retrieved_pages = [result['page'] for result in retrieval_results]
        validated_pages = [page for page in claimed_pages if page in retrieved_pages]
        if len(validated_pages) < len(claimed_pages):
            removed_pages = set(claimed_pages) - set(validated_pages)
            print(f"Warning: Removed {len(removed_pages)} hallucinated page references: {removed_pages}")
        if len(validated_pages) < min_pages and retrieval_results:
            existing_pages = set(validated_pages)
            for result in retrieval_results:
                page = result['page']
                if page not in existing_pages:
                    validated_pages.append(page)
                    existing_pages.add(page)
                    if len(validated_pages) >= min_pages:
                        break
        if len(validated_pages) > max_pages:
            print(f"Trimming references from {len(validated_pages)} to {max_pages} pages")
            validated_pages = validated_pages[:max_pages]
        return validated_pages

    def get_answer_for_company(self, company_name: str, question: str, schema: str) -> dict:
        # For a single company: retrieve context and call the LLM to generate an answer
        if self.llm_reranking:
            retriever = HybridRetriever(
                vector_db_dir=self.vector_db_dir,
                documents_dir=self.documents_dir
            )
        else:
            retriever = VectorRetriever(
                vector_db_dir=self.vector_db_dir,
                documents_dir=self.documents_dir
            )
        if self.full_context:
            retrieval_results = retriever.retrieve_all(company_name)
        else:
            retrieval_results = retriever.retrieve_by_company_name(
                company_name=company_name,
                query=question,
                llm_reranking_sample_size=self.llm_reranking_sample_size,
                top_n=self.top_n_retrieval,
                return_parent_pages=self.return_parent_pages
            )
        if not retrieval_results:
            raise ValueError("No relevant context found")
        rag_context = self._format_retrieval_results(retrieval_results)
        answer_dict = self.openai_processor.get_answer_from_rag_context(
            question=question,
            rag_context=rag_context,
            schema=schema,
            model=self.answering_model
        )
        self.response_data = self.openai_processor.response_data
        if self.new_challenge_pipeline:
            pages = answer_dict.get("relevant_pages", [])
            validated_pages = self._validate_page_references(pages, retrieval_results)
            answer_dict["relevant_pages"] = validated_pages
            answer_dict["references"] = self._extract_references(validated_pages, company_name)
        return answer_dict

    def _extract_companies_from_subset(self, question_text: str) -> list[str]:
        """Extract company names from the question text by matching against the companies in the subset file."""
        if not hasattr(self, 'companies_df'):
            if self.subset_path is None:
                raise ValueError("subset_path must be provided to use subset extraction")
            self.companies_df = pd.read_csv(self.subset_path)
        found_companies = []
        company_names = sorted(self.companies_df['company_name'].unique(), key=len, reverse=True)
        for company in company_names:
            escaped_company = re.escape(company)
            pattern = rf'{escaped_company}(?:\W|$)'
            if re.search(pattern, question_text, re.IGNORECASE):
                found_companies.append(company)
                question_text = re.sub(pattern, '', question_text, flags=re.IGNORECASE)
        return found_companies

    def process_question(self, question: str, schema: str):
        # Process a single question; supports multi-company comparisons
        if self.new_challenge_pipeline:
            extracted_companies = self._extract_companies_from_subset(question)
        else:
            extracted_companies = re.findall(r'"([^"]*)"', question)
        if len(extracted_companies) == 0:
            raise ValueError("No company name found in the question.")
        if len(extracted_companies) == 1:
            company_name = extracted_companies[0]
            answer_dict = self.get_answer_for_company(company_name=company_name, question=question, schema=schema)
            return answer_dict
        else:
            return self.process_comparative_question(question, extracted_companies, schema)

    def _create_answer_detail_ref(self, answer_dict: dict, question_index: int) -> str:
        """Create a reference ID for the answer details and store the detailed content."""
        ref_id = f"#/answer_details/{question_index}"
        with self._lock:
            self.answer_details[question_index] = {
                "step_by_step_analysis": answer_dict['step_by_step_analysis'],
                "reasoning_summary": answer_dict['reasoning_summary'],
                "relevant_pages": answer_dict['relevant_pages'],
                "response_data": self.response_data,
                "self": ref_id
            }
        return ref_id

    def _calculate_statistics(self, processed_questions: List[dict], print_stats: bool = False) -> dict:
        """Summarize processing results: totals, errors, N/A answers, and successes."""
        total_questions = len(processed_questions)
        error_count = sum(1 for q in processed_questions if "error" in q)
        na_count = sum(1 for q in processed_questions if (q.get("value") if "value" in q else q.get("answer")) == "N/A")
        success_count = total_questions - error_count - na_count
        if print_stats:
            print(f"\nFinal Processing Statistics:")
            print(f"Total questions: {total_questions}")
            print(f"Errors: {error_count} ({(error_count/total_questions)*100:.1f}%)")
            print(f"N/A answers: {na_count} ({(na_count/total_questions)*100:.1f}%)")
            print(f"Successfully answered: {success_count} ({(success_count/total_questions)*100:.1f}%)\n")
        return {
            "total_questions": total_questions,
            "error_count": error_count,
            "na_count": na_count,
            "success_count": success_count
        }

    def process_questions_list(self, questions_list: List[dict], output_path: str = None, submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = "") -> dict:
        # Process a list of questions, with optional parallelism and checkpointed saves
        total_questions = len(questions_list)
        # Attach an index to every question so its answer details can be located later
        questions_with_index = [{**q, "_question_index": i} for i, q in enumerate(questions_list)]
        self.answer_details = [None] * total_questions  # pre-allocate the answer-details list
        processed_questions = []
        parallel_threads = self.parallel_requests
        if parallel_threads <= 1:
            # Single-threaded sequential processing
            for question_data in tqdm(questions_with_index, desc="Processing questions"):
                processed_question = self._process_single_question(question_data)
                processed_questions.append(processed_question)
                if output_path:
                    self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
        else:
            # Multi-threaded parallel processing
            with tqdm(total=total_questions, desc="Processing questions") as pbar:
                for i in range(0, total_questions, parallel_threads):
                    batch = questions_with_index[i : i + parallel_threads]
                    with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_threads) as executor:
                        # executor.map keeps results in input order
                        batch_results = list(executor.map(self._process_single_question, batch))
                    processed_questions.extend(batch_results)
                    if output_path:
                        self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
                    pbar.update(len(batch_results))
        statistics = self._calculate_statistics(processed_questions, print_stats=True)
        return {
            "questions": processed_questions,
            "answer_details": self.answer_details,
            "statistics": statistics
        }

    def _process_single_question(self, question_data: dict) -> dict:
        question_index = question_data.get("_question_index", 0)
        if self.new_challenge_pipeline:
            question_text = question_data.get("text")
            schema = question_data.get("kind")
        else:
            question_text = question_data.get("question")
            schema = question_data.get("schema")
        try:
            answer_dict = self.process_question(question_text, schema)
            if "error" in answer_dict:
                detail_ref = self._create_answer_detail_ref({
                    "step_by_step_analysis": None,
                    "reasoning_summary": None,
                    "relevant_pages": None
                }, question_index)
                if self.new_challenge_pipeline:
                    return {
                        "question_text": question_text,
                        "kind": schema,
                        "value": None,
                        "references": [],
                        "error": answer_dict["error"],
                        "answer_details": {"$ref": detail_ref}
                    }
                else:
                    return {
                        "question": question_text,
                        "schema": schema,
                        "answer": None,
                        "error": answer_dict["error"],
                        "answer_details": {"$ref": detail_ref},
                    }
            detail_ref = self._create_answer_detail_ref(answer_dict, question_index)
            if self.new_challenge_pipeline:
                return {
                    "question_text": question_text,
                    "kind": schema,
                    "value": answer_dict.get("final_answer"),
                    "references": answer_dict.get("references", []),
                    "answer_details": {"$ref": detail_ref}
                }
            else:
                return {
                    "question": question_text,
                    "schema": schema,
                    "answer": answer_dict.get("final_answer"),
                    "answer_details": {"$ref": detail_ref},
                }
        except Exception as err:
            return self._handle_processing_error(question_text, schema, err, question_index)

    def _handle_processing_error(self, question_text: str, schema: str, err: Exception, question_index: int) -> dict:
        """Handle an exception raised while processing a question.
        Records the error details and returns a dict describing the failure."""
        import traceback
        error_message = str(err)
        tb = traceback.format_exc()
        error_ref = f"#/answer_details/{question_index}"
        error_detail = {
            "error_traceback": tb,
            "self": error_ref
        }
        with self._lock:
            self.answer_details[question_index] = error_detail
        print(f"Error encountered processing question: {question_text}")
        print(f"Error type: {type(err).__name__}")
        print(f"Error message: {error_message}")
        print(f"Full traceback:\n{tb}\n")
        if self.new_challenge_pipeline:
            return {
                "question_text": question_text,
                "kind": schema,
                "value": None,
                "references": [],
                "error": f"{type(err).__name__}: {error_message}",
                "answer_details": {"$ref": error_ref}
            }
        else:
            return {
                "question": question_text,
                "schema": schema,
                "answer": None,
                "error": f"{type(err).__name__}: {error_message}",
                "answer_details": {"$ref": error_ref},
            }

    def _post_process_submission_answers(self, processed_questions: List[dict]) -> List[dict]:
        """Post-process answers into the submission format:
        1. Convert page numbers from 1-based to 0-based
        2. Clear references for N/A answers
        3. Format to the competition submission schema
        4. Include step_by_step_analysis"""
        submission_answers = []
        for q in processed_questions:
            question_text = q.get("question_text") or q.get("question")
            kind = q.get("kind") or q.get("schema")
            value = "N/A" if "error" in q else (q.get("value") if "value" in q else q.get("answer"))
            references = q.get("references", [])
            answer_details_ref = q.get("answer_details", {}).get("$ref", "")
            step_by_step_analysis = None
            if answer_details_ref and answer_details_ref.startswith("#/answer_details/"):
                try:
                    index = int(answer_details_ref.split("/")[-1])
                    if 0 <= index < len(self.answer_details) and self.answer_details[index]:
                        step_by_step_analysis = self.answer_details[index].get("step_by_step_analysis")
                except (ValueError, IndexError):
                    pass
            # Clear references if value is N/A
            if value == "N/A":
                references = []
            else:
                # Convert page indices from one-based to zero-based (competition requires 0-based page indices, but for debugging it is easier to use 1-based)
                references = [
                    {
                        "pdf_sha1": ref["pdf_sha1"],
                        "page_index": ref["page_index"] - 1
                    }
                    for ref in references
                ]
            submission_answer = {
                "question_text": question_text,
                "kind": kind,
                "value": value,
                "references": references,
            }
            if step_by_step_analysis:
                submission_answer["reasoning_process"] = step_by_step_analysis
            submission_answers.append(submission_answer)
        return submission_answers

    def _save_progress(self, processed_questions: List[dict], output_path: Optional[str], submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = ""):
        if output_path:
            statistics = self._calculate_statistics(processed_questions)
            # Prepare debug content
            result = {
                "questions": processed_questions,
                "answer_details": self.answer_details,
                "statistics": statistics
            }
            output_file = Path(output_path)
            debug_file = output_file.with_name(output_file.stem + "_debug" + output_file.suffix)
            with open(debug_file, 'w', encoding='utf-8') as file:
                json.dump(result, file, ensure_ascii=False, indent=2)
            if submission_file:
                # Post-process answers for submission
                submission_answers = self._post_process_submission_answers(processed_questions)
                submission = {
                    "answers": submission_answers,
                    "team_email": team_email,
                    "submission_name": submission_name,
                    "details": pipeline_details
                }
                with open(output_file, 'w', encoding='utf-8') as file:
                    json.dump(submission, file, ensure_ascii=False, indent=2)

    def process_all_questions(self, output_path: str = 'questions_with_answers.json', team_email: str = "79250515615@yandex.com", submission_name: str = "Ilia_Ris SO CoT + Parent Document Retrieval", submission_file: bool = False, pipeline_details: str = ""):
        result = self.process_questions_list(
            self.questions,
            output_path,
            submission_file=submission_file,
            team_email=team_email,
            submission_name=submission_name,
            pipeline_details=pipeline_details
        )
        return result

    def process_comparative_question(self, question: str, companies: List[str], schema: str) -> dict:
        """Handle a comparative question across multiple companies:
        1. Rewrite the comparative question into one question per company
        2. Process each company in parallel
        3. Aggregate the results and generate the final comparative answer"""
        # Step 1: Rephrase the comparative question
        rephrased_questions = self.openai_processor.get_rephrased_questions(
            original_question=question,
            companies=companies
        )
        individual_answers = {}
        aggregated_references = []

        # Step 2: Process each individual question in parallel
        def process_company_question(company: str) -> tuple[str, dict]:
            """Helper function to process one company's question and return (company, answer)"""
            sub_question = rephrased_questions.get(company)
            if not sub_question:
                raise ValueError(f"Could not generate sub-question for company: {company}")
            answer_dict = self.get_answer_for_company(company_name=company, question=sub_question, schema="number")
            return company, answer_dict

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_company = {executor.submit(process_company_question, company): company for company in companies}
            for future in concurrent.futures.as_completed(future_to_company):
                try:
                    company, answer_dict = future.result()
                    individual_answers[company] = answer_dict
                    company_references = answer_dict.get("references", [])
                    aggregated_references.extend(company_references)
                except Exception as e:
                    company = future_to_company[future]
                    print(f"Error processing company {company}: {str(e)}")
                    raise

        # Remove duplicate references
        unique_refs = {}
        for ref in aggregated_references:
            key = (ref.get("pdf_sha1"), ref.get("page_index"))
            unique_refs[key] = ref
        aggregated_references = list(unique_refs.values())

        # Step 3: Get the comparative answer using all individual answers
        comparative_answer = self.openai_processor.get_answer_from_rag_context(
            question=question,
            rag_context=individual_answers,
            schema="comparative",
            model=self.answering_model
        )
        self.response_data = self.openai_processor.response_data
        comparative_answer["references"] = aggregated_references
        return comparative_answer
