当前位置: 首页 > news >正文

深度学习系列80:Pike-RAG解析

github地址在https://github.com/microsoft/PIKE-RAG

1. chunking

需要examples/chunking.py和examples/biology/configs/chunking.yml

1.1 chunking.yml文件

先看下配置文件:

  1. 去github上下载book并放在这个文件夹
input_doc_setting:
  doc_dir: data/biology/contents

后面chunk完的结果会放在下面:

output_doc_setting:
  doc_dir: data/biology/chunks
  1. 需要自定义个LLM,参考pikerag/llm_client下的文件,然后修改yml文件中的llm_client部分。
  2. 定义splitter,这里用的是pikerag.document_transformers.pikerag.document_transformers,其实就是RecursiveCharacterTextSplitter再加上一个chunk_summary。其中chunk_summary相关代码在prompts/chunking目录下,其中一个template如下:
chunk_summary_template_Chinese = MessageTemplate(
    template=[
        ("system", "You are a helpful AI assistant good at document summarization."),
        ("user", """
# 原文来源

原文来自 {source} 的政策文档 {filename}。

# 原文

“部分原文”:
{content}

# 任务要求

你的任务是输出以上“部分原文”的总结。

# 输出

只输出内容总结,不要添加其他任何内容。
""".strip())
    ],
    input_variables=["source", "filename", "content"],
)

splitter出来的Document会加上如下信息:chunk_meta.update({"summary": chunk_summary})

1.2 chunking.py

首先加载文件:from pikerag.document_loaders import get_loader,会推断文件类型,支持:

class DocumentType(Enum):
    csv = ["csv"]
    excel = ["xlsx", "xls"]
    markdown = ["md"]
    pdf = ["pdf"]
    text = ["txt"]
    word = ["docx", "doc"]

增加metadata:doc.metadata.update({"filename": doc_name})
分割文件并保存:chunk_docs = self._splitter.transform_documents(docs)

2. qa

需要examples/qa.py和examples/biology/configs/qa.yml

2.1 qa.yml文件

调用模块都写在yml文件里面了:

workflow:
  module_path: pikerag.workflows.qa
  class_name: QaWorkflow
  
qa_protocol:
  module_path: pikerag.prompts.qa
  # available attr_name: multiple_choice_qa_protocol, multiple_choice_qa_with_reference_protocol
  attr_name: multiple_choice_qa_with_reference_protocol
  template_partial:
    knowledge_domain: biological
    
llm_client:
  module_path: pikerag.llm_client
  # available class_name: AzureMetaLlamaClient, AzureOpenAIClient, HFMetaLlamaClient
  class_name: AzureOpenAIClient

retriever:
  module_path: pikerag.knowledge_retrievers
  class_name: QaChunkRetriever

    retrieval_query:
      module_path: pikerag.knowledge_retrievers.query_parsers
      func_name: question_plus_options_as_query
      
    vector_store:
      collection_name: biology_book

      id_document_loading:
        module_path: biology.utils
        func_name: load_ids_and_chunks
        args:
          chunk_file_dir: data/biology/chunks

      embedding_setting:
        args:
          model_name: "FremyCompany/BioLORD-2023"
          model_kwargs:
            device: "cuda:0"
          encode_kwargs:
            normalize_embeddings: True

2.2 模块分析

qa.py的代码非常简单:

    workflow = workflow_class(yaml_config)
    workflow.run()

这里顺着模块来梳理。qa的workflow在pikerag.workflows.qa下。
QaWorkflow类的初始化模块包括初始化logger、testing_suite、agent(包含了protocol、retriever、llm_client)、evaluator。
然后run函数调用_single_thread_run()或者_multiple_threads_run()。这里我们看单线程的版本,简单来说就是循环调用anser函数,那我们首先看下answer函数。

    # 步骤:使用retriver获取相关chunks、将问题和chunks组合成message、调用LLM回答问题
    def answer(self, qa: BaseQaData, question_idx: int) -> dict:
        reference_chunks = self._retriever.retrieve_contents(qa, retrieve_id=f"Q{question_idx:03}")
        messages = self._qa_protocol.process_input(content=qa.question, references=reference_chunks, **qa.as_dict())
        response = self._client.generate_content_with_messages(messages, **self.llm_config)
        output_dict: dict = self._qa_protocol.parse_output(response, **qa.as_dict())
        ……

2.3 retriever部分

核心代码为pikerage/workflows/qa.py中的answer函数.

2.3.1 retriever获取chunks

这里的retriever使用的是chroma_qa_retriever.py中的QaChunkRetriever类(同时使用到了mixins文件夹下的chroma_mixin.py)
代码为:reference_chunks: List[str] = self._retriever.retrieve_contents(qa, retrieve_id=f"Q{question_idx:03}")
retriever初始化用到的embedding函数是FremyCompany/BioLORD-2023。我们可以把模型下载到本地加载,也可以把它修改为本地自定义embedding。

这里调用了biology.utils下的load_ids_and_chunks函数,使用pickle.load()读取chunks,数据读取地址为data/biology/chunks(跟chunking.yml保持一致)。

2.3.2 protocol构造prompt

接下来是使用_qa_protocol.process_input处理问题和chunks,代码为:messages = self._qa_protocol.process_input(content=qa.question, references=reference_chunks, **qa.as_dict())

这里是pikerag.prompts.qa.multiple_choice.py中的MultipleChoiceQaWithReferenceParser函数,使用的template为multiple_choice_qa_with_reference_template。函数的功能就是用’\n’把多选项和references连接起来,和content一起填充到template中:

[
        ("system", "You are an helpful assistant good at {knowledge_domain} knowledge that can help people answer {knowledge_domain} questions."),
        ("user", """
# Task
Your task is to think step by step and then choose the correct option from the given options, the chosen option should be correct and the most suitable one to answer the given question. You can refer to the references provided when thinking and answering. Please note that the references may or may not be relevant to the question. If you don't have sufficient information to determine, randomly choose one option from the given options.

# Output format
The output should strictly follow the format below, do not add any redundant information.

<result>
  <thinking>Your thinking for the given question.</thinking>
  <answer>
    <mask>The chosen option mask. Please note that only one single mask is allowable.</mask>
    <option>The option detail corresponds to the chosen option mask.</option>
  </answer>
</result>

# Question
{content}

# Options
{options_str}

# References
{references_str}

# Thinking and Answer
""".strip()),
    ]

2.3.3 LLM获取response

代码为:response = self._client.generate_content_with_messages(messages, **self.llm_config)
实际调用的是_get_response_with_messages函数。我们以pikerag/llm_client/hf_meta_llama_client.py为例,其代码为:

self._client = AutoModelForCausalLM.from_pretrained(self._model_id, **kwargs)
input_ids = self._tokenizer.apply_chat_template(messages,add_generation_prompt=True, return_tensors="pt",).to(self._client.device)
outputs = self._client.generate(input_ids,pad_token_id=self._tokenizer.eos_token_id,**llm_config,)
response = outputs[0][input_ids.shape[-1]:]

最后再调用下面的代码解析结果:
output_dict: dict = self._qa_protocol.parse_output(response, **qa.as_dict())

相关文章:

  • C#工业上位机实时信号边沿检测实现
  • 【前端】【nuxt】几种在 Nuxt 客户端使用console的方式
  • 基于vue框架的在线考试系统的设计与实现l9385(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。
  • python编写的一个打砖块小游戏
  • Leetcode做题记录----1
  • 数据结构:有序表的插入
  • @EnableDiscoveryClient和@EnableEurekaClient springboot3.x
  • AI数字人| Fay开源项目、UE5数字人、本地大模型
  • 计算机网络----第一章:概论
  • LLM开源大模型汇总(截止2025.03.09)
  • OpenText ETX 助力欧洲之翼航空公司远程工作升级
  • 2路模拟量同步输出卡、任意波形发生器卡—PCIe9100数据采集卡
  • JSON.parse(JSON.stringify())深拷贝不会复制函数
  • 长方形旋转计算最后的外层宽高,单元测试
  • HttpMediaTypeNotAcceptableException报错解决,状态码显示为406
  • 深度学习分类回归(衣帽数据集)
  • 98.在 Vue3 中使用 OpenLayers 根据 Resolution 的不同显示不同的地图
  • 【C++】数据类型
  • LLM中的transformer结构学习(二 完结 Multi-Head Attention、Encoder、Decoder)
  • Python之pyqt5生成计算机前端页面并运行
  • 陕西籍青年作家卜文哲爬山时发生意外离世,终年28岁
  • 英伟达回应在上海设立新办公空间:正租用一个新办公空间,这是在中国持续深耕的努力
  • 让中小学生体验不同职业,上海中高职院校提供超5万个体验名额
  • 小米法务部:犯罪团伙操纵近万账号诋毁小米,该起黑公关案告破
  • “当代阿炳”甘柏林逝世,创办了国内第一所残疾人高等学府
  • 广东一驴友在英德野景点溺亡,家属被爆向21名同伴索赔86万