langchain官网翻译:Build a Question/Answering system over SQL data
文章目录
- Build a Question/Answering system over SQL data
- Security note
- Architecture
- Setup
- Sample data
- Chains
- Application state
- Convert question to SQL query
- Execute query
- Generate answer
- Orchestrating with LangGraph
- Human-in-the-loop
- Next steps
- Agents
- System Prompt
- Initializing agent
- Dealing with high-cardinality columns
本文翻译自langchain官网:https://python.langchain.com/docs/tutorials/sql_qa/
Build a Question/Answering system over SQL data
让大语言模型(LLM)系统查询结构化数据的方法,与处理非结构化文本数据存在本质差异。针对非结构化文本通常采用生成文本再通过向量数据库进行检索的方式;而对于结构化数据,常见做法是让LLM使用领域特定语言(如SQL)编写并执行查询语句。本指南将详细介绍对于数据库的表格式数据上构建问答系统的基本方法,涵盖基于chains和agents两种实现方案。这些系统能够让我们用自然语言提问数据库中的数据,并获取自然语言形式的答案。两者的核心区别在于:智能代理可以根据需要循环查询数据库,直至得出问题答案。
Security note
构建基于SQL数据库的问答系统需要执行模型生成的SQL查询,这一过程存在固有风险。请务必确保数据库连接权限始终按照链/代理的实际需求进行最小化范围限定。虽然这不能完全消除模型驱动系统的风险,但能有效降低潜在威胁。关于通用安全最佳实践的更多信息,点击这里查看。
Architecture
从宏观层面看,这类系统的运行包含三个核心步骤:
- 问题转换成SQL查询:模型将用户输入的自然语言问题转换为SQL查询语句
- 执行SQL查询:在数据库中执行生成的查询语句
- 生成问题答案:模型基于查询结果组织自然语言响应
需要说明的是,对CSV格式数据的查询也可采用类似流程。具体实现方法可参阅我们关于CSV数据问答的专题指南。
Setup
首先,安装必要的软件包并设置环境变量:
%pip install --upgrade --quiet langchain-community langgraph
# 请注释下方代码以选择不在本笔记本中使用LangSmith。此项为可选配置。
if not os.environ.get("LANGSMITH_API_KEY"):os.environ["LANGSMITH_API_KEY"] = getpass.getpass()os.environ["LANGSMITH_TRACING"] = "true"
Sample data
以下示例将使用包含Chinook数据库的SQLite连接。该数据库是一个代表digital media store的示例数据库。请按照以下安装步骤,在本笔记本同一目录下创建Chinook.db文件。您也可以通过命令行下载并构建该数据库:
curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db
现在,Chinook.db 已存在于我们的目录中,我们可以通过 SQLAlchemy 驱动的 SQLDatabase 类与之进行交互:
from langchain_community.utilities import SQLDatabasedb = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
太好了!我们现在已有一个可供查询的 SQL 数据库。接下来,让我们尝试将其与大型语言模型(LLM)连接起来。
Chains
Chains是由可预测步骤组成的组合。在LangGraph中,我们可以通过简单的节点序列来表示Chain。让我们创建一个步骤序列,在给定问题的情况下执行以下操作:
- 将问题转换为SQL查询语句
- 执行该查询语句
- 利用查询结果回答原始问题
这种配置存在一些不支持的应用场景
。例如,该系统会对任何用户输入(即使是"你好"这样的问候语)都执行SQL查询。更重要的是,正如我们将在下文看到的,某些问题需要执行多次查询才能解答。我们将在"Agents"章节中解决这些场景问题。
Application state
LangGraph state 负责管理输入数据、步骤间传递的数据以及最终输出数据。该状态通常采用TypedDict类型定义,但也可以使用Pydantic BaseModel实现。
针对当前应用,我们只需跟踪以下核心数据:
- 输入问题
- 生成的查询语句
- 查询结果
- 生成的最终答案
from typing_extensions import TypedDict
class State(TypedDict):question: strquery: strresult: stranswer: str
现在,我们只需要编写基于该状态进行操作的函数,并填充其内容。
Convert question to SQL query
第一步,将用户输入转换为SQL查询语句。为了可靠地获取纯SQL查询语句(避免Markdown格式、解释说明等附加内容),我们将利用LangChain的结构化输出抽象功能。
现在让我们为程序选择合适的聊天模型:
pip install -qU "langchain[openai]"
import getpass
import os
if not os.environ.get("OPENAI_API_KEY"):os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")from langchain.chat_models import init_chat_model
llm = init_chat_model("gpt-4o-mini", model_provider="openai")
接下来,为模型提供一些指令说明:
from langchain_core.prompts import ChatPromptTemplatesystem_message = """
根据输入的问题,创建语法正确的{dialect}查询语句以帮助找到答案。
除非用户在问题中明确指定了希望获取的记录数量,否则始终将查询结果限制在最多{top_k}条。
您可以通过相关列对结果进行排序,以返回数据库中最具代表性的样本。禁止查询特定表的所有列,只需根据问题要求返回少数相关列。注意仅使用模式描述中可见的列名,谨慎避免查询不存在的列。同时注意各列所属的表。仅可使用以下表:
{table_info}
"""user_prompt = "问题:{input}"query_prompt_template = ChatPromptTemplate([("system", system_message), ("user", user_prompt)]
)for message in query_prompt_template.messages:message.pretty_print()
该提示模板包含多个需要填充的参数,例如SQL dialect和表结构信息。LangChain的SQLDatabase对象提供了辅助方法来处理这些参数。我们的write_query步骤将填充这些参数并提示模型生成SQL查询语句:
from typing_extensions import Annotatedclass QueryOutput(TypedDict):"""生成的SQL查询语句"""query: Annotated[str, ..., "语法有效的SQL查询语句"]def write_query(state: State):"""生成用于获取信息的SQL查询"""prompt = query_prompt_template.invoke({"dialect": db.dialect, # 数据库方言"top_k": 10, # 结果数量限制"table_info": db.get_table_info(), # 表结构信息"input": state["question"], # 用户问题})structured_llm = llm.with_structured_output(QueryOutput)result = structured_llm.invoke(prompt)return {"query": result["query"]} # 返回生成的查询语句
接下来测试下:
write_query({"question": "How many Employees are there?"}){'query': 'SELECT COUNT(*) as employee_count FROM Employee;'}
Execute query
这是构建SQL任务链中最危险的环节。请审慎评估是否允许在数据上运行自动化查询。应尽可能限制数据库连接权限,并考虑在执行查询前在任务链中添加人工审核步骤(详见下文说明)。
为了执行查询,我们将从langchain-community加载一个工具。我们的execute_query节点将封装该工具:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTooldef execute_query(state: State):"""执行 SQL query."""execute_query_tool = QuerySQLDatabaseTool(db=db)return {"result": execute_query_tool.invoke(state["query"])}
测试下:
execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"}){'result': '[(8,)]'}
Generate answer
最后,最终步骤将根据从数据库提取的信息生成问题的答案:
def generate_answer(state: State):"""基于检索到的信息回答问题"""prompt = ("请根据以下用户问题、对应的SQL查询语句及查询结果,回答用户问题:\n\n"f"问题:{state['question']}\n"f"SQL查询:{state['query']}\n"f"查询结果:{state['result']}")response = llm.invoke(prompt)return {"answer": response.content}
Orchestrating with LangGraph
最后,我们将应用程序编译为统一的图对象。在这个案例中,我们只需将三个步骤连接成单一序列即可。
from langgraph.graph import START, StateGraphgraph_builder = StateGraph(State).add_sequence([write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))
让我们测试应用程序吧!请注意,我们可以实时流式查看每个步骤的执行结果:
for step in graph.stream({"question": "How many employees are there?"}, stream_mode="updates"
):print(step){'write_query': {'query': 'SELECT COUNT(*) as employee_count FROM Employee;'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are 8 employees in total.'}}
Human-in-the-loop
LangGraph支持许多对该工作流有用的功能,其中一项是"人在回路"机制:我们可以在执行敏感操作(例如运行SQL查询)前中断应用程序,进行人工审核。这通过LangGraph的持久化层实现,该层会将运行进度保存至您选择的存储中。下面我们指定使用内存存储:
from langgraph.checkpoint.memory import MemorySaver
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])# 由于使用了持久化机制,我们需要指定线程ID
# 以便在人工审核后能够继续执行流程
config = {"configurable": {"thread_id": "1"}}display(Image(graph.get_graph().draw_mermaid_png()))
让我们重复相同的运行流程,并添加一个简单的"是/否"审批步骤:
for step in graph.stream({"question": "有多少名员工?"},config,stream_mode="updates",
):print(step)try:user_approval = input("是否继续执行查询?(是/否): ")
except Exception:user_approval = "否"if user_approval.lower() == "是":# 如果获得批准,继续执行图流程for step in graph.stream(None, config, stream_mode="updates"):print(step)
else:print("操作已被用户取消。")
Next steps
针对更复杂的查询生成场景,我们可能需要创建小样本提示模板或添加查询校验步骤。有关此类高级技巧及更多内容,请参阅:
提示词策略:高级提示工程技术指南
查询校验:添加查询验证与错误处理机制
大型数据库:处理大规模数据库的技术方案
Agents
Agents利用大型语言模型的推理能力,在执行过程中自主决策。通过使用Agent,您可以将query generation 和执行过程中的额外判断权交给LLM。虽然Agent的行为比上述"任务链"更难以预测,但它们具有以下优势:
- 能够根据需求多次查询数据库以解答用户问题
- 具备错误恢复能力:执行生成的查询语句→捕获错误信息→重新生成正确查询
- 既能基于数据库内容回答问题,也能解析数据库结构(例如描述特定表)
下面我们将写一个最小化的SQL agent。使用LangChain的SQLDatabaseToolkit为其配备工具集,通过LangGraph预构建的ReAct agent 构造器,仅需一行代码即可实现。
请参阅 LangGraph 的 SQL agent教程,了解更高级的 SQL 代理实现方案。参考:https://github.langchain.ac.cn/langgraph/tutorials/sql/sql-agent/
SQLDatabaseToolkit 工具包包含以下功能:
- 创建并执行查询
- 检查查询语法
- 获取表结构描述
- 其他功能
from langchain_community.agent_toolkits import SQLDatabaseToolkittoolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
tools[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x119315480>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x119317550>, root_client=<openai.OpenAI object at 0x10d5f8df0>, root_async_client=<openai.AsyncOpenAI object at 0x1193154e0>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x119315480>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x119317550>, root_client=<openai.OpenAI object at 0x10d5f8df0>, root_async_client=<openai.AsyncOpenAI object at 0x1193154e0>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]
System Prompt
我们还需要为agent加载系统提示System Prompt,其中包含行为规范指令。请注意下方的提示模板包含多个参数,我们将在后续进行赋值。
system_message = """
您是一个专用于与SQL数据库交互的智能代理。
根据输入问题,请先创建语法正确的{dialect}查询语句并执行,
然后分析查询结果并返回答案。除非用户明确指定要获取的记录数量,
否则始终将查询结果限制在最多{top_k}条以内。您可以通过相关列对结果排序,以返回数据库中最具代表性的数据。
禁止查询特定表的所有列,仅需获取与问题相关的字段。在执行查询前必须仔细检查语句。如果执行时出现错误,
请重新编写查询语句并再次尝试。严禁向数据库执行任何DML语句(INSERT、UPDATE、DELETE、DROP等)。开始时应始终先查看数据库中的表结构以确定可查询内容,
切勿跳过此步骤。随后应查询最相关表的模式信息。
""".format(dialect="SQLite",top_k=5,
)
Initializing agent
接下来使用prebuilt LangGraph agent 来搭建agent。
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agentagent_executor = create_react_agent(llm, tools, prompt=system_message)
请观察agent对以下问题的响应方式:
question = "Which country's customers spent the most?"for step in agent_executor.stream({"messages": [{"role": "user", "content": question}]},stream_mode="values",
):step["messages"][-1].pretty_print()
响应很长:
================================[1m Human Message [0m=================================Which country's customers spent the most?
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_list_tables (call_tFp7HYD6sAAmCShgeqkVZH6Q)Call ID: call_tFp7HYD6sAAmCShgeqkVZH6QArgs:
=================================[1m Tool Message [0m=================================
Name: sql_db_list_tablesAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_schema (call_KJZ1Jx6JazyDdJa0uH1UeiOz)Call ID: call_KJZ1Jx6JazyDdJa0uH1UeiOzArgs:table_names: Customer, Invoice
=================================[1m Tool Message [0m=================================
Name: sql_db_schemaCREATE TABLE "Customer" ("CustomerId" INTEGER NOT NULL, "FirstName" NVARCHAR(40) NOT NULL, "LastName" NVARCHAR(20) NOT NULL, "Company" NVARCHAR(80), "Address" NVARCHAR(70), "City" NVARCHAR(40), "State" NVARCHAR(40), "Country" NVARCHAR(40), "PostalCode" NVARCHAR(10), "Phone" NVARCHAR(24), "Fax" NVARCHAR(24), "Email" NVARCHAR(60) NOT NULL, "SupportRepId" INTEGER, PRIMARY KEY ("CustomerId"), FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/CREATE TABLE "Invoice" ("InvoiceId" INTEGER NOT NULL, "CustomerId" INTEGER NOT NULL, "InvoiceDate" DATETIME NOT NULL, "BillingAddress" NVARCHAR(70), "BillingCity" NVARCHAR(40), "BillingState" NVARCHAR(40), "BillingCountry" NVARCHAR(40), "BillingPostalCode" NVARCHAR(10), "Total" NUMERIC(10, 2) NOT NULL, PRIMARY KEY ("InvoiceId"), FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_query_checker (call_AQuTGbgH63u4gPgyV723yrjX)Call ID: call_AQuTGbgH63u4gPgyV723yrjXArgs:query: SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;
=================================[1m Tool Message [0m=================================
Name: sql_db_query_checker\`\`\`sql
SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;
\`\`\`
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_query (call_B88EwU44nwwpQL5M9nlcemSU)Call ID: call_B88EwU44nwwpQL5M9nlcemSUArgs:query: SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;
=================================[1m Tool Message [0m=================================
Name: sql_db_query[('USA', 523.06)]
==================================[1m Ai Message [0m==================================The country whose customers spent the most is the USA, with a total spending of 523.06.
请注意,该agent会执行多次查询直至获取所需信息:
- 首先列出可用数据表
- 随后获取三个表的schema
- 最后通过join操作对多个表进行查询
该代理同样可以处理定性分析类问题:
question = "描述 playlisttrack 表的结构信息"
for step in agent_executor.stream({"messages": [{"role": "user", "content": question}]},stream_mode="values",
):step["messages"][-1].pretty_print()
输出如下:
================================[1m Human Message [0m=================================Describe the playlisttrack table
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_list_tables (call_fMF8eTmX5TJDJjc3Mhdg52TI)Call ID: call_fMF8eTmX5TJDJjc3Mhdg52TIArgs:
=================================[1m Tool Message [0m=================================
Name: sql_db_list_tablesAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_schema (call_W8Vkk4NEodkAAIg8nexAszUH)Call ID: call_W8Vkk4NEodkAAIg8nexAszUHArgs:table_names: PlaylistTrack
=================================[1m Tool Message [0m=================================
Name: sql_db_schemaCREATE TABLE "PlaylistTrack" ("PlaylistId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL, PRIMARY KEY ("PlaylistId", "TrackId"), FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/
==================================[1m Ai Message [0m==================================The `PlaylistTrack` table is designed to associate tracks with playlists. It has the following structure:- **PlaylistId**: An integer that serves as a foreign key referencing the `Playlist` table. It is part of the composite primary key.
- **TrackId**: An integer that serves as a foreign key referencing the `Track` table. It is also part of the composite primary key.The primary key for this table is a composite key consisting of both `PlaylistId` and `TrackId`, ensuring that each track can be uniquely associated with a playlist. The table enforces referential integrity by linking to the `Track` and `Playlist` tables through foreign keys.
Dealing with high-cardinality columns
为了准确筛选包含专有名词(如地址、歌曲名或艺术家名)的列,我们首先需要核对拼写以确保数据过滤的准确性。
这可以通过以下方式实现:创建一个包含数据库中所有唯一专有名词的向量存储库。当用户问题中包含专有名词时,agent可随时查询该向量库以确认正确拼写。通过这种方式,agent能在构建目标查询前准确理解用户所指的实体对象。
首先我们需要获取目标实体的唯一值集合,为此我们定义一个,能够将解析结果转换为列表的函数:
import ast
import redef query_as_list(db, query):res = db.run(query)res = [el for sub in ast.literal_eval(res) for el in sub if el]res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]return list(set(res))artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]
输出:
['In Through The Out Door','Transmission','Battlestar Galactica (Classic), Season','A Copland Celebration, Vol. I','Quiet Songs']
通过此函数,我们可以创建一个检索工具a retriever tool,供agent自主调用。
接下来为此步骤选择嵌入模型和向量数据库(embeddings model and vector store):
官网此处有多种选择,我们任选其一。
选择嵌入模型:
pip install -qU langchain-huggingface
from langchain_huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
选择向量数据库:
pip install -qU langchain-core
from langchain_core.vectorstores import InMemoryVectorStore
vector_store = InMemoryVectorStore(embeddings)
现在我们可以构建一个检索工具,用于在数据库中的相关专有名词中进行搜索:
from langchain.agents.agent_toolkits import create_retriever_tool# 将艺术家和专辑名称添加到向量库
_ = vector_store.add_texts(artists + albums)
# 创建检索器,设置返回最相似的5个结果
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
# 工具功能描述
description = ("用于查找需要过滤的数值。输入是专有名词的近似拼写,""输出是有效的专有名词。选择与搜索词最相似的名词。"
)
# 创建检索工具
retriever_tool = create_retriever_tool(retriever,name="search_proper_nouns",description=description,
)
print(retriever_tool.invoke("Alice Chains"))
Alice In Chains
Alanis Morissette
Pearl Jam
Pearl Jam
Audioslave
通过这种方式,当agent确定需要基于类似"Alice Chains"的艺术家名称编写过滤条件时,它可以首先使用检索工具来查看该列的相关取值。
将以上组件整合起来:
# 添加到系统消息
suffix = ("如果需要根据专有名词(如名称)进行过滤,你必须始终先使用""'search_proper_nouns'工具查找过滤值!不要尝试猜测专有名词""的正确写法——使用此功能来查找相似名称。"
)system = f"{system_message}\n\n{suffix}"
tools.append(retriever_tool)agent = create_react_agent(llm, tools, prompt=system)question = "alis in chain 有多少张专辑?"
for step in agent.stream({"messages": [{"role": "user", "content": question}]},stream_mode="values",
):step["messages"][-1].pretty_print()
================================[1m Human Message [0m=================================How many albums does alis in chain have?
==================================[1m Ai Message [0m==================================
Tool Calls:search_proper_nouns (call_8ryjsRPLAr79mM3Qvnq6gTOH)Call ID: call_8ryjsRPLAr79mM3Qvnq6gTOHArgs:query: alis in chain
=================================[1m Tool Message [0m=================================
Name: search_proper_nounsAlice In ChainsAisha DuoXisDa Lama Ao CaosA-Sides
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_list_tables (call_NJjtCpU89MBMplssjn1z0xzq)Call ID: call_NJjtCpU89MBMplssjn1z0xzqArgs:search_proper_nouns (call_1BfrueC9koSIyi4OfMu2Ao8q)Call ID: call_1BfrueC9koSIyi4OfMu2Ao8qArgs:query: Alice In Chains
=================================[1m Tool Message [0m=================================
Name: search_proper_nounsAlice In ChainsPearl JamPearl JamFoo FightersSoundgarden
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_schema (call_Kn09w9jd9swcNzIZ1b5MlKID)Call ID: call_Kn09w9jd9swcNzIZ1b5MlKIDArgs:table_names: Album, Artist
=================================[1m Tool Message [0m=================================
Name: sql_db_schemaCREATE TABLE "Album" ("AlbumId" INTEGER NOT NULL, "Title" NVARCHAR(160) NOT NULL, "ArtistId" INTEGER NOT NULL, PRIMARY KEY ("AlbumId"), FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)/*
3 rows from Album table:
AlbumId Title ArtistId
1 For Those About To Rock We Salute You 1
2 Balls to the Wall 2
3 Restless and Wild 2
*/CREATE TABLE "Artist" ("ArtistId" INTEGER NOT NULL, "Name" NVARCHAR(120), PRIMARY KEY ("ArtistId")
)/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
==================================[1m Ai Message [0m==================================
Tool Calls:sql_db_query (call_WkHRiPcBoGN9bc58MIupRHKP)Call ID: call_WkHRiPcBoGN9bc58MIupRHKPArgs:query: SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')
=================================[1m Tool Message [0m=================================
Name: sql_db_query[(1,)]
==================================[1m Ai Message [0m==================================Alice In Chains has released 1 album in the database.
正如我们所见,无论是在流式输出步骤中还是在LangSmith追踪记录里,agent都使用了search_proper_nouns工具来核实如何正确查询该特定艺术家的数据库信息。