当前位置: 首页 > news >正文

基于SQL数据库的智能问答系统设计与实现

一、食用指南

使 LLM 系统能够查询结构化数据在性质上可能与非结构化文本数据有所不同。非结构化文本数据通常生成可针对向量数据库进行搜索的文本,而结构化数据的方法通常是让 LLM 编写并执行 DSL(如SQL)查询。

本文将介绍在数据库中创建表格数据问答系统的基本方法,我们可以向该系统提出关于数据库数据的问题,最终获得自然语言答案。我们将使用 Chain 和 Agent 两种不同的实现方式,两者之间的主要区别在于,Agent 可以根据需要多次循环查询数据库以回答问题。

二、信息安全

构建 SQL 数据库的问答系统需要执行模型生成的 SQL 查询,这样做存在固有风险,请确保数据库连接权限始终尽可能地窄,以满足链/代理的需求,但注意,这只能减轻而不能完全消除构建模型驱动系统的风险。

三、架构说明

  • 将问题转换为 SQL 查询:模型将用户输入转换为 SQL 查询。
  • 执行 SQL 查询:执行查询。
  • 回答问题:模型使用查询结果响应用户输入。

在这里插入图片描述

四、安装依赖

%pip install --upgrade --quiet langchain-community langgraph

五、示例数据

使用 SQLite 连接到 Chinook 数据库,这是一个代表数字媒体商店的示例数据库。

# 下载并构建数据库
curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db
$ sqlite3 Chinook.db 
SQLite version 3.45.3 2024-04-15 13:34:05
Enter ".help" for usage hints.
sqlite> .read Chinook_Sqlite.sql
sqlite> .quit

在这里插入图片描述

使用 SQLAlchemy 驱动的 SQLDatabase 类进行交互:

from langchain_community.utilities import SQLDatabasedb = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

至此,我们有一个可以查询的 SQL 数据库,现在尝试将其连接到 LLM。

六、Chain

链是可预测步骤的组合,在 LangGraph 中可以通过简单的节点序列来表示一个链。这里创建一系列步骤,给定一个问题,执行以下操作:

  • 将问题转换为SQL查询;
  • 执行查询;
  • 使用结果回答原始问题。

这种安排不支持某些场景,例如,该系统将对任何用户输入(甚至是“你好”)执行 SQL 查询,更重要的是,有些问题需要多次查询才能回答,我们将在 Agent 部分讨论这些场景。

1、应用程序状态

LangGraph 应用程序的状态控制着哪些数据作为输入进入应用程序、在步骤之间传输以及由应用程序输出,它通常是 TypedDict,但也可以是 Pydantic BaseModel。

对于此应用程序,我们只需跟踪输入问题、生成的查询、查询结果和生成的答案。

from typing_extensions import TypedDictclass State(TypedDict):question: strquery: strresult: stranswer: str

2、将问题转换为SQL查询

第一步是获取用户输入并将其转换为 SQL 查询,为了可靠地获取 SQL 查询(没有markdown格式和解释或澄清),我们将利用 LangChain 的结构化输出抽象。在此之前,需要为应用程序选择一个聊天模型:

%pip install langchain-openai
import osos.environ["OPENAI_BASE_URL"] = "https://api.siliconflow.cn/v1/"
os.environ["OPENAI_API_KEY"] = "sk-xxx"from langchain.chat_models import init_chat_model# llm = init_chat_model("Qwen/Qwen3-8B", model_provider="openai")
llm = init_chat_model("THUDM/GLM-Z1-9B-0414", model_provider="openai")
# llm = init_chat_model("deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", model_provider="openai")

这里使用硅基流动的免费模型服务,以上代码中使用 THUDM/GLM-Z1-9B-0414 模型,当然也可以使用其他免费模型(deepseek-ai/DeepSeek-R1-0528-Qwen3-8B或者Qwen/Qwen3-8B),直接复制官网上的模型名称即可,点击这里直达官网,注册完成后创建一个 API 密钥就能使用模型了。

在这里插入图片描述
在这里插入图片描述

现在为模型提供一些指令:

from langchain_core.prompts import ChatPromptTemplatesystem_message = """
当收到输入问题时,需构建符合{dialect}语法规范的查询语句以获取答案。
除非用户明确指定需要获取的示例数量,否则始终将查询结果限制在最多{top_k}条。
可通过相关列对结果排序以返回数据库中最具价值的数据。
切勿查询特定表的所有列,仅需获取问题相关的少数几列。请注意:
1. 仅使用模式描述中可见的列名
2. 避免查询不存在的列
3. 明确各列所属的表
可用的表仅限于:{table_info}
"""user_prompt = "Question: {input}"query_prompt_template = ChatPromptTemplate([("system", system_message), ("user", user_prompt)]
)for message in query_prompt_template.messages:message.pretty_print()
================================ System Message ================================当收到输入问题时,需构建符合{dialect}语法规范的查询语句以获取答案。
除非用户明确指定需要获取的示例数量,否则始终将查询结果限制在最多{top_k}条。
可通过相关列对结果排序以返回数据库中最具价值的数据。
切勿查询特定表的所有列,仅需获取问题相关的少数几列。请注意:
1. 仅使用模式描述中可见的列名
2. 避免查询不存在的列
3. 明确各列所属的表
可用的表仅限于:{table_info}================================ Human Message =================================Question: {input}

提示词包含了我们需要填充的几个参数,write_query 将填充这些参数并提示模型生成SQL查询:

from typing_extensions import Annotatedclass QueryOutput(TypedDict):"""Generated SQL query."""query: Annotated[str, ..., "Syntactically valid SQL query."]def write_query(state: State):"""Generate SQL query to fetch information."""prompt = query_prompt_template.invoke({"dialect": db.dialect,"top_k": 10,"table_info": db.get_table_info(),"input": state["question"],})structured_llm = llm.with_structured_output(QueryOutput)result = structured_llm.invoke(prompt)return {"query": result["query"]}

测试:

write_query({"question": "有多少员工?"})
{'query': 'SELECT COUNT(*) AS TotalEmployees FROM Employee'}

这里需要注意一点,llm.with_structured_output 的前提是大模型本身支持结构化输出,并非所有模型都支持 .with_structured_output(),这里不展开具体介绍,若感兴趣可以查阅:告别杂乱文本:模型响应结构化处理技巧大全。

3、执行查询

这是创建 SQL 链最危险的部分,需要仔细考虑:是否可以在数据库上运行自动化查询?是否已经最小化数据库连接权限?

为了执行查询,我们将从 langchain-community 加载一个工具,用 execute_query 节点封装此工具:

from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTooldef execute_query(state: State):"""Execute SQL query."""execute_query_tool = QuerySQLDatabaseTool(db=db)return {"result": execute_query_tool.invoke(state["query"])}

测试:

execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"})
{'result': '[(8,)]'}

4、生成答案

我们的最后一步会根据从数据库中提取的信息生成问题的答案:

def generate_answer(state: State):"""Answer question using retrieved information as context."""prompt = ("根据以下给定的用户问题、相应的 SQL 查询和 SQL 结果,回答用户问题。\n\n"f"Question: {state['question']}\n"f"SQL Query: {state['query']}\n"f"SQL Result: {state['result']}")response = llm.invoke(prompt)return {"answer": response.content}

5、使用 LangGraph 进行编排

将应用程序编译成一个单独的 graph 对象,在这种情况下,我们只是将这三个步骤连接成一个单一的序列:

from langgraph.graph import START, StateGraphgraph_builder = StateGraph(State).add_sequence([write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

LangGraph 还内置了用于可视化应用程序控制流的实用程序:

from IPython.display import Image, displaydisplay(Image(graph.get_graph().draw_mermaid_png()))

在这里插入图片描述

OK,现在可以测试一下 graph,用流式传输单个步骤的结果:

for step in graph.stream({"question": "有多少员工?"}, stream_mode="updates"
):print(step)
{'write_query': {'query': 'SELECT COUNT(*) AS TotalEmployees FROM Employee;'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': '\n根据查询结果,目前共有 **8名员工**。'}}

6、人工审查

LangGraph 支持许多对此工作流有用的功能,其中之一是人工干预:我们可以在敏感步骤(例如SQL查询的执行)之前中断应用程序以进行人工审查,这需要借助 LangGraph 的持久化层实现,它将运行进度保存到某个存储中,下面,我们指定使用内存存储。

若需使用数据库存储可参考:Postgres Checkpointer 实战:优化数据持久化性能的策略

from langgraph.checkpoint.memory import MemorySavermemory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])# Now that we're using persistence, we need to specify a thread ID
# so that we can continue the run after review.
config = {"configurable": {"thread_id": "1"}}
display(Image(graph.get_graph().draw_mermaid_png()))

在这里插入图片描述

在这里添加一个简单的“是/否”审批步骤:

for step in graph.stream({"question": "有多少员工?"},config,stream_mode="updates",
):print(step)try:user_approval = input("是否执行查询? (yes/no): ")
except Exception:user_approval = "no"if user_approval.lower() == "yes":# If approved, continue the graph executionfor step in graph.stream(None, config, stream_mode="updates"):print(step)
else:print("Operation cancelled by user.")

在这里插入图片描述

{'write_query': {'query': 'SELECT COUNT(*) FROM Employee;'}}
{'__interrupt__': ()}
是否执行查询? (yes/no):  yes
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': '\n根据提供的SQL查询结果,There are 8 employees in the company.\n\nExplanation:\nThe SQL查询 "SELECT COUNT(*) FROM Employee;" calculates the number of rows in the Employee table by counting all (*) columns.\nThe result [(8,)] shows that there are 8 records in the Employee table.\nTherefore, the answer to "有多少员工?" (How many employees are there?) is 8.'}}

七、Agent

Agent 利用 LLM 的推理能力在执行期间做出决策,使用 Agent 可以将查询生成和执行过程中的额外判断权交给它,尽管它的行为不如上述 Chain 可预测,但它具有一些优势:

  • 可以根据需要多次查询数据库以回答用户问题。
  • 可以通过运行生成的查询、捕获回溯并正确地重新生成来从错误中恢复。
  • 可以根据数据库的模式以及数据库的内容(例如描述特定表)回答问题。

下面我们组装一个最小化的 SQL Agent,使用 LangChain 的 SQLDatabaseToolkit 为其配备一套工具,使用 LangGraph 的预构建 ReAct Agent 构造器,我们可以在一行代码中完成此操作:

from langchain_community.agent_toolkits import SQLDatabaseToolkittoolkit = SQLDatabaseToolkit(db=db, llm=llm)tools = toolkit.get_tools()tools
[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7dbd258261a0>),InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7dbd258261a0>),ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7dbd258261a0>),QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7dbd258261a0>, llm=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x7dbd25827610>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x7dbd25826e60>, root_client=<openai.OpenAI object at 0x7dbd1ea02020>, root_async_client=<openai.AsyncOpenAI object at 0x7dbd258258a0>, model_name='THUDM/GLM-Z1-9B-0414', model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x7dbd25827610>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x7dbd25826e60>, root_client=<openai.OpenAI object at 0x7dbd1ea02020>, root_async_client=<openai.AsyncOpenAI object at 0x7dbd258258a0>, model_name='THUDM/GLM-Z1-9B-0414', model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]

1、系统提示

system_message = """
你是一个与SQL数据库交互的智能代理。当接收到用户输入的问题时,需生成符合语法规范的{dialect}查询语句并执行,随后根据查询结果返回答案。若用户未指定具体示例数量,默认最多返回{top_k}条结果。可通过相关字段排序以获取数据库中更具代表性的数据。核心规则:‌1. 精准查询‌:仅选择与问题直接相关的字段,避免全表扫描
2. 安全验证‌:执行前必须双重检查SQL语句有效性
3. 只读原则‌:严禁执行任何DML操作(INSERT/UPDATE/DELETE等)
4. 初始化流程‌:
- 优先查看数据库表结构
- 分析相关表的元数据信息
5. 容错机制‌:若执行报错需立即重构查询语句重试
""".format(dialect="SQLite",top_k=5,
)

2、初始化

from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agentagent_executor = create_react_agent(llm, tools, prompt=system_message)

注意,这里 llm 使用了 Qwen3 8B 而不是 GLM-Z1-9B-0414,因为分析过程中需要更强的推理能力。

3、测试

请注意,agent 会执行多次查询,直到它获取到所需的信息:

  • 列出可用表;
  • 检索三个表的模式;
  • 通过联接操作查询多个表。

agent 随后能够使用最终查询的结果生成对原始问题的答案。

question = "哪个国家的客户消费最多?"for step in agent_executor.stream({"messages": [{"role": "user", "content": question}]},stream_mode="values",
):step["messages"][-1].pretty_print()
================================ Human Message =================================哪个国家的客户消费最多?
================================== Ai Message ==================================
Tool Calls:sql_db_schema (0198efb1d5b91082f776ea80d84f9322)Call ID: 0198efb1d5b91082f776ea80d84f9322Args:table_names: customers,orders
================================= Tool Message =================================
Name: sql_db_schemaError: table_names {'orders', 'customers'} not found in database
================================== Ai Message ==================================
Tool Calls:sql_db_list_tables (0198efb21e45a991d3866c657a745a56)Call ID: 0198efb21e45a991d3866c657a745a56Args:tool_input:
================================= Tool Message =================================
Name: sql_db_list_tablesAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:sql_db_schema (0198efb26c13369c5d500be1e7a3c643)Call ID: 0198efb26c13369c5d500be1e7a3c643Args:table_names: Customer,Invoice,InvoiceLine
================================= Tool Message =================================
Name: sql_db_schemaCREATE TABLE "Customer" ("CustomerId" INTEGER NOT NULL, "FirstName" NVARCHAR(40) NOT NULL, "LastName" NVARCHAR(20) NOT NULL, "Company" NVARCHAR(80), "Address" NVARCHAR(70), "City" NVARCHAR(40), "State" NVARCHAR(40), "Country" NVARCHAR(40), "PostalCode" NVARCHAR(10), "Phone" NVARCHAR(24), "Fax" NVARCHAR(24), "Email" NVARCHAR(60) NOT NULL, "SupportRepId" INTEGER, PRIMARY KEY ("CustomerId"), FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	luisg@embraer.com.br	3
2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	leonekohler@surfeu.de	5
3	François	Tremblay	None	1498 rue Bélanger	Montréal	QC	Canada	H2G 1A7	+1 (514) 721-4711	None	ftremblay@gmail.com	3
*/CREATE TABLE "Invoice" ("InvoiceId" INTEGER NOT NULL, "CustomerId" INTEGER NOT NULL, "InvoiceDate" DATETIME NOT NULL, "BillingAddress" NVARCHAR(70), "BillingCity" NVARCHAR(40), "BillingState" NVARCHAR(40), "BillingCountry" NVARCHAR(40), "BillingPostalCode" NVARCHAR(10), "Total" NUMERIC(10, 2) NOT NULL, PRIMARY KEY ("InvoiceId"), FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)/*
3 rows from Invoice table:
InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
1	2	2021-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
2	4	2021-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
3	8	2021-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
*/CREATE TABLE "InvoiceLine" ("InvoiceLineId" INTEGER NOT NULL, "InvoiceId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL, "UnitPrice" NUMERIC(10, 2) NOT NULL, "Quantity" INTEGER NOT NULL, PRIMARY KEY ("InvoiceLineId"), FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)/*
3 rows from InvoiceLine table:
InvoiceLineId	InvoiceId	TrackId	UnitPrice	Quantity
1	1	2	0.99	1
2	1	4	0.99	1
3	2	6	0.99	1
*/
================================== Ai Message ==================================
Tool Calls:sql_db_query_checker (0198efb2ed3ac4e6c5c3981488a937dc)Call ID: 0198efb2ed3ac4e6c5c3981488a937dcArgs:query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 5
================================= Tool Message =================================
Name: sql_db_query_checkerSQL Query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 5
================================== Ai Message ==================================
Tool Calls:sql_db_query (0198efb42e528a2761803c3119ea9fd2)Call ID: 0198efb42e528a2761803c3119ea9fd2Args:query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 5
================================= Tool Message =================================
Name: sql_db_query[('USA', 523.06), ('Canada', 303.96), ('France', 195.1), ('Brazil', 190.1), ('Germany', 156.48)]
================================== Ai Message ==================================消费最多的国家是美国(USA),其总消费金额为523.06美元。以下是各国家的总消费排名:1. **USA** - 523.06美元  
2. **Canada** - 303.96美元  
3. **France** - 195.10美元  
4. **Brazil** - 190.10美元  
5. **Germany** - 156.48美元  数据基于客户所在的国家(Country)统计其所有发票的总消费额(Total),按消费金额降序排列。

在这里插入图片描述

agent 也可以类似地处理定性问题:

question = "描述播放列表的数据表"for step in agent_executor.stream({"messages": [{"role": "user", "content": question}]},stream_mode="values",
):step["messages"][-1].pretty_print()
================================ Human Message =================================描述播放列表的数据表
================================== Ai Message ==================================
Tool Calls:sql_db_list_tables (0198efb4a4d2d9eec1c838cf5a9f9c24)Call ID: 0198efb4a4d2d9eec1c838cf5a9f9c24Args:tool_input:
================================= Tool Message =================================
Name: sql_db_list_tablesAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:sql_db_schema (0198efb4c6a0941b1a7239fb7dc35124)Call ID: 0198efb4c6a0941b1a7239fb7dc35124Args:table_names: Playlist
================================= Tool Message =================================
Name: sql_db_schemaCREATE TABLE "Playlist" ("PlaylistId" INTEGER NOT NULL, "Name" NVARCHAR(120), PRIMARY KEY ("PlaylistId")
)/*
3 rows from Playlist table:
PlaylistId	Name
1	Music
2	Movies
3	TV Shows
*/
================================== Ai Message ==================================播放列表数据表结构如下:**表名**:Playlist  
**字段**:  
- PlaylistId:整数类型,非空,主键  
- Name:可变长度字符串类型(NVARCHAR(120)),用于存储播放列表名称  **示例数据**:  
| PlaylistId | Name        |  
|-------------|-------------|  
| 1           | Music       |  
| 2           | Movies      |  
| 3           | TV Shows    |  该表记录了所有播放列表的基本信息,包含唯一标识符(PlaylistId)和名称(Name)。若需要查看播放列表与曲目的关联数据,可进一步查询 PlaylistTrack 表。

在这里插入图片描述

4、处理高基数列

为了过滤包含专有名词(如地址、歌曲名称或艺术家)的列,我们首先需要仔细检查拼写,以正确过滤数据。我们可以通过创建一个包含数据库中所有不同专有名词的向量存储来实现这一点,然后,每当用户在问题中包含专有名词时,让 agent 查询该向量存储,以找到该词的正确拼写。通过这种方式,agent 可以确保在构建目标查询之前,它理解用户指的是哪个实体。

首先,将结果解析为元素列表:

import ast
import redef query_as_list(db, query):res = db.run(query)res = [el for sub in ast.literal_eval(res) for el in sub if el]res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]return list(set(res))artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]
['Synkronized','Appetite for Destruction','Respighi:Pines of Rome','UB40 The Best Of - Volume Two [UK]','Unplugged']

使用此函数创建一个检索工具,agent 可以自行决定执行该工具。

接下来,选择一个嵌入模型和向量存储,这里使用 bge-m3 和内存向量存储,当然也可以使用其他模型和存储,详情可参考:LangChain × PGVector:手把手教你搭建智能向量数据库。

from langchain_openai import OpenAIEmbeddingsembeddings = OpenAIEmbeddings(model="BAAI/bge-m3",base_url="http://localhost:8000/v1", api_key="EMPTY")

当然,这里也可以不指定 base_url 跟 api_key,这会从环境变量中读取:

os.environ["OPENAI_BASE_URL"] = "https://api.siliconflow.cn/v1/"
os.environ["OPENAI_API_KEY"] = "sk-xxx"

以上我们已经定义过这两个环境变量,但是考虑到硅基流动的向量模型有 batch size 限制,所以这里使用 vLLM 部署了一个本地的模型。

input batch size 604 > maximum allowed batch size 64
from langchain_core.vectorstores import InMemoryVectorStorevector_store = InMemoryVectorStore(embeddings)

然后,构建一个检索工具,用于在数据库中搜索相关的专有名词:

from langchain.agents.agent_toolkits import create_retriever_tool_ = vector_store.add_texts(artists + albums)
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
description = ("用于查找要过滤的值。输入是专有名词的近似拼写,输出是有效的专有名词。使用与搜索最相似的名词。"
)
retriever_tool = create_retriever_tool(retriever,name="search_proper_nouns",description=description,
)

测试:

print(retriever_tool.invoke("Alice Chains"))
Alice In ChainsThe Office, SeasonThe PoliceVs.Otto

在这里插入图片描述

最后,我们构建并测试这个 agent:

# Add to system message
suffix = ("当需要按专有名词(如人名)进行筛选时,必须首先使用 'search_proper_nouns' 工具查找过滤值!切勿尝试猜测专有名词——应通过该功能查找相似名称。"
)system = f"{system_message}\n\n{suffix}"tools.append(retriever_tool)agent = create_react_agent(llm, tools, prompt=system)
question = "Alice In Chains有多少张专辑?"for step in agent.stream({"messages": [{"role": "user", "content": question}]},stream_mode="values",
):step["messages"][-1].pretty_print()
================================ Human Message =================================Alice In Chains有多少张专辑?
================================== Ai Message ==================================
Tool Calls:sql_db_list_tables (0198efe1b559303496107aba7bde5d86)Call ID: 0198efe1b559303496107aba7bde5d86Args:tool_input:
================================= Tool Message =================================
Name: sql_db_list_tablesAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:sql_db_schema (0198efe20fa61f90a2322e571bf7f136)Call ID: 0198efe20fa61f90a2322e571bf7f136Args:table_names: Artist, Album
================================= Tool Message =================================
Name: sql_db_schemaCREATE TABLE "Album" ("AlbumId" INTEGER NOT NULL, "Title" NVARCHAR(160) NOT NULL, "ArtistId" INTEGER NOT NULL, PRIMARY KEY ("AlbumId"), FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/CREATE TABLE "Artist" ("ArtistId" INTEGER NOT NULL, "Name" NVARCHAR(120), PRIMARY KEY ("ArtistId")
)/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/
================================== Ai Message ==================================
Tool Calls:search_proper_nouns (0198efe26e374d2e2724eeca95902518)Call ID: 0198efe26e374d2e2724eeca95902518Args:query: Alice In Chains
================================= Tool Message =================================
Name: search_proper_nounsAlice In ChainsThe Office, SeasonBattlestar Galactica, SeasonLOST, SeasonVs.
================================== Ai Message ==================================
Tool Calls:sql_db_query_checker (0198efe2cf658f95aeb045fe4f809ea0)Call ID: 0198efe2cf658f95aeb045fe4f809ea0Args:query: SELECT COUNT(*) FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains';
================================= Tool Message =================================
Name: sql_db_query_checkerSELECT COUNT(*) FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains';
================================== Ai Message ==================================
Tool Calls:sql_db_query (0198efe37087cb837c37661fc7047f6f)Call ID: 0198efe37087cb837c37661fc7047f6fArgs:query: SELECT COUNT(*) FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains';
================================= Tool Message =================================
Name: sql_db_query[(1,)]
================================== Ai Message ==================================Alice In Chains共有1张专辑。

在这里插入图片描述

参考文档

  • https://python.langchain.ac.cn/docs/tutorials/sql_qa/
http://www.dtcms.com/a/355922.html

相关文章:

  • Codeforces Round 1043 (Div. 3) D. From 1 to Infinity
  • 2025年9月计算机二级C++语言程序设计——选择题打卡Day9
  • 【数据分享】珠江三角洲水系地理空间全套数据集
  • x64dbg的基本调试操作 (未完,待补充)
  • 通信协议再升级,PROFINET和EtherNet IP网关迎接改造升级大挑战
  • 智慧清洁革新者:有鹿机器人自述
  • @Jenkins 介绍、部署与使用标准作业程序
  • 深入 OpenHarmony 内核:设备待机管理模块的休眠调度与资源节能技术
  • AT_abc407_f [ABC407F] Sums of Sliding Window Maximum
  • 告别低效!三坐标测量机提高油缸导向套检测效率
  • 拷贝构造和赋值重载有什么区别
  • 转发、重定向
  • 什么是强化学习? ——— 帮助新手了解
  • 基于51单片机的远程wifi浇花系统设计
  • Snagit 2025.3.0 截图贴图录像编辑
  • Android Keystore签名文件详解与安全防护
  • shell编程学习
  • 基于深度学习的档案级图像修复:Coderformer AI技术解析与应用实践
  • 一、晶振与布局布线处理
  • Python Imaging Library (PIL) 全面指南:Python Imaging Library (PIL)基础图像处理入门
  • 呼叫中心录音加密与数据隔离技术方案全解析
  • Wagtail 扩展 HomePage 模型(一个简单的 例子)
  • 人工智能-python-深度学习-过拟合与欠拟合:概念、判断与解决方法
  • 鸿蒙Harmony-从零开始构建类似于安卓GreenDao的ORM数据库(三)
  • 详解多智能体架构:以 Open Deep Research 项目为例
  • Android中设置RecyclerView滑动到指定条目位置
  • container向harbor推送镜像报错 x509: certificate signed by unknown authority
  • redis添加超时设置
  • SONiC 之 Testbed(2)Ansible
  • Ansible 角色与 Galaxy 生态:角色复用、集合安装与系统角色配置详解