1. Manually implementing data processing and pipeline orchestration
from glob import glob
import os
from openai import OpenAI
from pymilvus import MilvusClient
from tqdm import tqdm
import json
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

# Use SiliconFlow's free embedding model
openai_client = OpenAI(
    api_key="***",
    base_url="https://api.siliconflow.cn/v1",
)
milvus_client = MilvusClient(uri="./milvus_demo.db")
collection_name = "my_rag_collection"# 使用智谱的免费文本生成模型
llm = ChatOpenAI(
    temperature=0.6,
    model="glm-4.5",
    openai_api_key="***",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)


def emb_long_text(text):
    # Split long text into 512-character chunks, embed each chunk,
    # then mean-pool the chunk embeddings into a single vector.
    chunk_size = 512
    if len(text) <= chunk_size:
        return emb_text(text)
    embeddings = []
    for i in range(0, len(text), chunk_size):
        chunk = text[i : i + chunk_size]
        embedding = emb_text(chunk)
        embeddings.append(embedding)
    # Average the embeddings of all chunks
    avg_embedding = [sum(x) / len(embeddings) for x in zip(*embeddings)]
    return avg_embedding


def emb_text(text):
    return (
        openai_client.embeddings.create(input=text, model="BAAI/bge-m3")
        .data[0]
        .embedding
    )


def create_data_if_need():
    if milvus_client.has_collection(collection_name):
        print(milvus_client.describe_collection(collection_name))
        return
    text_lines = []
    for file_path in glob(
        os.path.expanduser("~/Desktop/milvus_docs/**/*.md"), recursive=True
    ):
        with open(file_path, "r", encoding="utf-8") as file:
            file_text = file.read()
        text_lines += file_text.split("# ")
    embedding_dim = 1024
    milvus_client.create_collection(
        collection_name=collection_name,
        dimension=embedding_dim,
        metric_type="IP",  # Inner product distance
        consistency_level="Bounded",
    )
    data = []
    for i, line in enumerate(tqdm(text_lines, desc="Creating embeddings")):
        if not line.strip():
            continue
        vector = emb_long_text(line)
        if not vector:
            print(f"Failed to embed line: {line}")
            continue
        data.append({"id": i, "vector": vector, "text": line, "text_len": len(line)})
    milvus_client.insert(collection_name=collection_name, data=data)


def do_chat(context, question):
    SYSTEM_PROMPT = """You are an AI assistant. Answer the user's question based on the provided context.
If the context contains no relevant information, honestly tell the user that you do not know the answer instead of fabricating one.
You must answer strictly from the context and must not add any information that is not in it."""
    USER_PROMPT = f"""Using the information in the context tags below, answer the user's question in the question tags in Chinese.
<context>
{context}
</context>
<question>
{question}
</question>"""
    # Build the message list
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=USER_PROMPT),
    ]
    # Invoke the model
    response = llm.invoke(messages)
    print(response.content)


if __name__ == "__main__":
    create_data_if_need()
    while True:
        question = input("Enter your question: ")
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[emb_text(question)],
            limit=35,
            filter="text_len > 500",
            search_params={"metric_type": "IP", "params": {}},
            output_fields=["text"],
        )
        retrieved_lines_with_distances = [
            (res["entity"]["text"], res["distance"]) for res in search_res[0]
        ]
        print(json.dumps(retrieved_lines_with_distances, indent=4))
        context = "\n".join(
            line_with_distance[0]
            for line_with_distance in retrieved_lines_with_distances
        )
        do_chat(context, question)
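A quick way to see what the mean pooling in emb_long_text computes is to run the same list comprehension on tiny made-up vectors. The sketch below is standalone and purely illustrative: the 3-dimensional chunk vectors are hypothetical (real BAAI/bge-m3 embeddings have 1024 dimensions).

# Sketch: the element-wise mean pooling used in emb_long_text,
# demonstrated on hypothetical 3-dimensional chunk embeddings.
def mean_pool(embeddings):
    return [sum(x) / len(embeddings) for x in zip(*embeddings)]

chunk_vectors = [
    [1.0, 0.0, 0.0],  # hypothetical embedding of chunk 1
    [0.0, 1.0, 0.0],  # hypothetical embedding of chunk 2
]
print(mean_pool(chunk_vectors))  # -> [0.5, 0.5, 0.0]

One caveat worth noting: if the embedding service returns unit-length vectors, their mean is generally shorter than unit length, so under the IP metric a long, multi-chunk document can score slightly lower than a single-chunk one; re-normalizing the pooled vector is a common mitigation.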
2. Automating data processing and pipeline orchestration with LangChain
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_milvus import Milvus
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pymilvus import MilvusClient
from glob import glob
from tqdm import tqdm
import os, json

# Create the embedding model
embeddings = OpenAIEmbeddings(
    model="BAAI/bge-m3",
    openai_api_key="***",
    openai_api_base="https://api.siliconflow.cn/v1",
)
# Create the language model
llm = ChatOpenAI(
    temperature=0.6,
    model="glm-4.5-flash",
    openai_api_key="***",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)
URI = "./milvus_demo_v2.db"
collection_name = "my_rag_collection_v2"# 加载或创建向量存储
def load_vectorstore():
    vector_store = Milvus(
        collection_name=collection_name,
        embedding_function=embeddings,
        connection_args={"uri": URI},
    )
    milvus_client = MilvusClient(uri=URI)
    if milvus_client.has_collection(collection_name):
        print("Vector store already exists; loading it directly...")
        return vector_store
    print("Creating a new vector store...")
    # Load the documents
    documents = []
    for file_path in glob(
        # Docs download: https://github.com/milvus-io/milvus-docs/releases/download/v2.4.6-preview/milvus_docs_2.4.x_en.zip
        os.path.expanduser("~/Desktop/milvus_docs/**/*.md"),
        recursive=True,
    ):
        try:
            loader = TextLoader(file_path, encoding="utf-8")
            documents.extend(loader.load())
        except Exception as e:
            print(f"Error loading file {file_path}: {e}")
    if not documents:
        raise ValueError("No documents found!")
    # milvus-docs is already split per file, so no further text splitting is done
    # text_splitter = RecursiveCharacterTextSplitter(
    #     chunk_size=800,
    #     chunk_overlap=100,
    #     length_function=len,
    # )
    # splits = text_splitter.split_documents(documents)
    # Add text-length metadata so retrieval can filter on it later
    for doc in documents:
        doc.metadata["text_len"] = len(doc.page_content)
    print(f"{len(documents)} documents to insert")
    # Too large a stride may hit API limits
    stride = 10
    for i in tqdm(range(0, len(documents), stride), desc="Adding documents to the vector store..."):
        sub_splits = documents[i : i + stride]
        vector_store.add_documents(sub_splits)
    return vector_store


def format_docs(docs):
    for doc in docs:
        print(doc)
    print("=" * 75)
    print("========== Vector store retrieval finished; waiting for the LLM response ==========")
    print("=" * 75)
    # Merge the related documents into one string to submit to the LLM
    return "\n\n".join(doc.page_content for doc in docs)


if __name__ == "__main__":
    # Create or load the vector store
    vector_store = load_vectorstore()
    while True:
        query = input("Enter your question: ")
        # Define the prompt template
        prompt_template = """You are an AI assistant. Answer the user's question based on the provided context.
If the context contains no relevant information, honestly tell the user that you do not know the answer instead of fabricating one.
You must answer strictly from the context and must not add any information that is not in it.
Using the information in the context tags below, answer the user's question in the question tags in Chinese.
<context>
{context}
</context>
<question>
{question}
</question>"""
        prompt = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )
        retriever = vector_store.as_retriever(
            # Retrieve the top-10 related documents, filtered by the expression
            search_kwargs=dict(k=10, expr="text_len > 300")
        )
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        final_res = rag_chain.invoke(query)
        print(final_res)