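Exploring how LangChain converts user questions into SQL queries
LangChain connects directly to MySQL, converts the user's question into a MySQL query, and then answers the question based on the query result.
1 Data preparation
This section first imports the sample Chinook data into MySQL and then connects to the database through LangChain.
1.1 Importing the data
The Chinook test dataset is used here. The MySQL version can be downloaded from the link below.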
https://github.com/lerocha/chinook-database/blob/master/ChinookDatabase/DataSources/Chinook_MySql.sql
The data is assumed to be already imported. For the import code, see the following link.
https://blog.csdn.net/liliang199/article/details/153821509
1.2 Connecting to the database
LangChain's SQLDatabase utility is used to connect to the database.
from langchain_community.utilities import SQLDatabase

# Connection settings for the local MySQL instance
db_user="root"
db_password="mysql"
db_host="localhost"
db_name="Chinook"
uri = f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"
ldb = SQLDatabase.from_uri(uri)
table_infos = ldb.get_table_info()
table_infos holds the schema information of the Chinook database.
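The URI above works because the sample password contains no special characters. If a password contains characters such as @ or /, the URI would be parsed incorrectly unless the credentials are percent-encoded. The following is a minimal sketch of that idea; build_mysql_uri and the password "p@ss/word" are hypothetical and not part of the original setup or of LangChain.

```python
from urllib.parse import quote_plus


def build_mysql_uri(user: str, password: str, host: str, db_name: str) -> str:
    """Build a SQLAlchemy-style MySQL URI, percent-encoding the credentials.

    Hypothetical helper for illustration only.
    """
    return f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}@{host}/{db_name}"


# Hypothetical password containing '@' and '/', which would break a naive f-string URI
uri = build_mysql_uri("root", "p@ss/word", "localhost", "Chinook")
print(uri)
```

The resulting string can be passed to SQLDatabase.from_uri in the same way as the hand-written URI.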
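2 Converting a question into a SQL query
Both LangChain chains and agents rely on prompts to convert a question into SQL. This section first shows the conversion with a hand-written prompt, then with a chain.
2.1 Conversion with a prompt
Using a prompt, the LLM is asked to convert the following question into a SQL query.
question="What are the names of employee with BirthDate in January?"
The example code is shown below.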
prompt_template = """
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {question}
"""

import os

# Placeholder credentials and endpoint; replace with real values
os.environ['OPENAI_API_KEY'] = "sk-xxxx"
os.environ['OPENAI_BASE_URL'] = "https://llm_provider_url"

from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

question = "What are the names of employee with BirthDate in January?"

# input_variables must match the template placeholders: {table_info} and {question}
prompt_template = PromptTemplate(template=prompt_template, input_variables=["table_info", "question"])
prompt = prompt_template.format(question=question, table_info=table_infos)
messages = [HumanMessage(content=prompt)]

llm = ChatOpenAI(model="deepseek-r1")
# invoke() is the current call style; calling llm(messages) directly is deprecated
sql_query = llm.invoke(messages)
print(sql_query.content)
The LLM output is as follows.
SQLQuery: SELECT `FirstName`, `LastName` FROM `Employee` WHERE MONTH(`BirthDate`) = 1 LIMIT 5;
SQLResult: [Assuming the query runs against the sample data provided, which doesn't have January birthdays]
FirstName LastName
... (no results)

Answer: There are no employees with birthdays in January based on the available data.
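Note that the code in this section only sends the prompt to the model and never executes the query, so the SQLResult and Answer lines above are written by the model itself rather than read from the database. To obtain a real result, the generated SQL has to be run, for example through the SQLDatabase object's run method against MySQL. The sketch below illustrates the same step without a MySQL server: it uses an in-memory SQLite database as a stand-in, registers a small MONTH function because SQLite does not provide one, and inserts two made-up rows that are not the real Chinook data.

```python
import sqlite3

# In-memory SQLite database used as a stand-in for MySQL (assumption for illustration)
conn = sqlite3.connect(":memory:")

# SQLite has no MONTH(); register a minimal version for "YYYY-MM-DD ..." strings
conn.create_function("MONTH", 1, lambda d: int(d[5:7]))

conn.execute("CREATE TABLE Employee (FirstName TEXT, LastName TEXT, BirthDate TEXT)")
# Hypothetical rows, not taken from the Chinook dataset
conn.executemany(
    "INSERT INTO Employee VALUES (?, ?, ?)",
    [
        ("Ann", "Example", "1970-01-15 00:00:00"),
        ("Bob", "Sample", "1980-06-30 00:00:00"),
    ],
)

# The query generated by the LLM above, executed unchanged
sql = "SELECT `FirstName`, `LastName` FROM `Employee` WHERE MONTH(`BirthDate`) = 1 LIMIT 5;"
rows = conn.execute(sql).fetchall()
print(rows)
```

With real data the answer should be derived from rows returned this way, not from the model's guessed SQLResult.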
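2.2 Conversion with a chain
create_sql_query_chain builds a chain from the LLM llm and the database ldb, and the chain converts the question into a SQL query. The example code is shown below.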
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain

query = "What are the names of employee with BirthDate in January?"

llm = ChatOpenAI(model="deepseek-r1")
chain = create_sql_query_chain(llm, ldb)
response = chain.invoke({"question": query})
print(response)
The output is as follows.
SQLQuery: SELECT `FirstName`, `LastName` FROM `Employee` WHERE MONTH(`BirthDate`) = 1 LIMIT 5;
SQLResult
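The output above still carries the "SQLQuery:" label and a trailing "SQLResult" marker, so it cannot be passed to the database as is. The following is a minimal sketch of one way to strip both; extract_sql is a hypothetical helper, not a LangChain function, and it assumes the output keeps the shape shown above.

```python
import re


def extract_sql(text: str) -> str:
    """Return the SQL statement from text shaped like 'SQLQuery: ...\\nSQLResult...'.

    Hypothetical helper for illustration only.
    """
    # Drop everything from the first 'SQLResult' marker onward, if present
    text = re.split(r"\n\s*SQLResult", text, maxsplit=1)[0]
    # Drop a leading 'SQLQuery:' label, if present
    text = re.sub(r"^\s*SQLQuery:\s*", "", text)
    return text.strip()


# The chain output shown above, reproduced as a string
response = (
    "SQLQuery: SELECT `FirstName`, `LastName` FROM `Employee` "
    "WHERE MONTH(`BirthDate`) = 1 LIMIT 5;\nSQLResult"
)
sql = extract_sql(response)
print(sql)
```

Text that has neither the label nor the marker passes through unchanged apart from whitespace trimming.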
The prompt used by the chain can be printed as follows.
chain.get_prompts()[0].pretty_print()
The printed prompt is as follows.
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}
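The chain's prompt matches the hand-written one in section 2.1 except for the placeholder name: the question goes into {input} here instead of {question}. The sketch below fills a shortened stand-in for the template with plain str.format to show where the schema and the question end up; the one-table schema string is hypothetical and only stands in for the output of ldb.get_table_info().

```python
# Shortened stand-in for the template above; only the placeholders matter here
template = (
    "Only use the following tables:\n"
    "{table_info}\n\n"
    "Question: {input}"
)

# Hypothetical one-table schema string, standing in for ldb.get_table_info()
table_info = "CREATE TABLE `Employee` (`FirstName` TEXT, `LastName` TEXT, `BirthDate` DATETIME)"

filled = template.format(
    table_info=table_info,
    input="What are the names of employee with BirthDate in January?",
)
print(filled)
```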
References
---
A SQL DB knowledge base system based on LangChain chains
https://blog.csdn.net/liliang199/article/details/153208506
A SQL DB knowledge base system based on a LangGraph agent
https://blog.csdn.net/liliang199/article/details/153317678
Relational database datasets - northwind & chinook
https://blog.csdn.net/liliang199/article/details/153821509
