当前位置: 首页 > news >正文

利用 LangChain 和一个大语言模型(LLM)构建一个链条,自动从用户输入的问题中提取相关的 SQL 表信息,再生成对应的 SQL 查询

示例代码:

from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# 系统消息,要求 LLM 返回与问题相关的 SQL 表类别
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""

# 初始化 LLM 模型
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

# 创建提取链:将用户问题转换为 Table 模型的实例
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

# 定义一个函数,根据 Table 对象映射到具体的 SQL 表名
def get_tables(categories: List[Table]) -> List[str]:
    """将类别名称映射到对应的 SQL 表名列表."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

# 将类别提取链与映射函数组合,得到一个返回 SQL 表名列表的链
table_chain = category_chain | get_tables 

# 定义自定义 SQL 提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

# 创建 SQL 查询链
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)

# 利用 bind 将固定参数绑定到 SQL 查询链中
bound_chain = query_chain.bind(
    dialect=db.dialect,
    table_info=db.get_table_info(),
    top_k=55
)

# 将输入中的 "question" 键复制到 "input" 键,同时保留原始数据
table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain

# 使用 RunnablePassthrough.assign 将提取到的表名添加到上下文中,然后与 SQL 查询链组合
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain

# 调用整个链,生成 SQL 查询
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
)
print(query)

这段代码主要展示如何利用 LangChain 和一个大语言模型(LLM)构建一个链条,自动从用户输入的问题中提取相关的 SQL 表信息,再生成对应的 SQL 查询。下面我将分步详细解释每个部分的作用,并通过举例说明每段代码的输入和输出。


1. 定义系统消息和初始化 LLM 模型

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
  • 作用:
    这段系统消息告诉 LLM:请根据用户的问题返回与问题相关的 SQL 表类别,这里限定了两类——“Music”和“Business”。

  • 举例:
    如果用户的问题涉及音乐信息(例如歌曲、专辑等),那么 LLM 会返回 “Music”;如果涉及客户、发票等信息,则返回 “Business”。

table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)
  • 作用:
    初始化一个 LLM 模型(此处使用 llama3-70b-8192,由 groq 提供,温度设为 0 以保证回答确定性),后续会用这个模型进行类别提取和 SQL 查询生成。

  • 输出:
    返回一个 LLM 实例 table_extractor_llm


2. 创建提取链:从问题中抽取相关表类别

category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
  • 作用:
    这里利用 create_extraction_chain_pydantic 创建了一个链,该链的任务是将用户输入的问题转换为一个或多个符合 Pydantic 模型 Table 的实例。也就是说,LLM 会分析问题并输出如 Table(name="Music")Table(name="Business") 的结果。

  • 输入:
    用户问题(例如 “What are all the genres of Alanis Morisette songs? Do not repeat!”)。

  • 输出:
    一个或多个 Table 对象,指明问题相关的表类别。例如,对于这个问题,可能返回 [Table(name="Music")]
    在这里插入图片描述


3. 定义映射函数,将类别映射到具体的 SQL 表名

def get_tables(categories: List[Table]) -> List[str]:
    """将类别名称映射到对应的 SQL 表名列表."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables
  • 作用:
    此函数接收前面提取链返回的 Table 对象列表,根据类别名称映射到具体的 SQL 表名列表:

    • 如果类别是 “Music”,则映射为音乐相关的多个表(如 Album、Artist、Genre 等)。
    • 如果类别是 “Business”,则映射为商业相关的表(如 Customer、Invoice 等)。
  • 举例:

    • 输入: [Table(name="Music")]
    • 输出: ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]

4. 组合提取链和映射函数

table_chain = category_chain | get_tables
  • 作用:
    利用管道操作符(|)将 category_chainget_tables 组合起来。整个链条(table_chain)的作用就是:接收用户问题 → 利用 LLM 提取相关类别 → 将类别映射为具体的 SQL 表名列表。

  • 输入:
    一个包含用户问题的字典(例如 {"question": "..."})。

  • 输出:
    一个 SQL 表名列表,如上例中的音乐相关表名列表。
    在这里插入图片描述


5. 定义自定义 SQL 提示模板

custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
  • 作用:
    该模板为生成 SQL 查询提供指令:

    • 指定 SQL 方言(如 MySQL、PostgreSQL 等)。
    • 提供数据库的表结构信息。
    • 告诉 LLM 根据问题({input})生成正确的 SQL 查询。
    • 不要限制返回行数,并且只返回 SQL 语句本身,无额外说明。
  • 举例:
    如果传入:

    • dialect: “SQLite”
    • input: “What are all the genres of Alanis Morisette songs? Do not repeat!”
    • table_info: 数据库所有表的结构信息
    • top_k: 55
      那么模板会指导 LLM 输出类似下面的 SQL 查询(实际内容由 LLM 根据 schema 生成):
    SELECT DISTINCT Genre.Name FROM Track
    JOIN Genre ON Track.GenreId = Genre.GenreId
    JOIN Artist ON Track.ArtistId = Artist.ArtistId
    WHERE Artist.Name = 'Alanis Morisette';
    

6. 创建 SQL 查询链并绑定固定参数

query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)
  • 作用:
    利用同一个 LLM 实例和预定义的 SQL 提示模板,创建一个 SQL 查询链。该链将根据数据库表结构(db)和用户问题生成 SQL 查询。
bound_chain = query_chain.bind(
    dialect=db.dialect,
    table_info=db.get_table_info(),
    top_k=55
)
  • 作用:
    通过 bind 方法将一些固定的参数绑定到 SQL 查询链上:

    • dialect:数据库使用的 SQL 方言。
    • table_info:数据库中所有表的结构信息。
    • top_k:限制返回的行数,这里设定为 55 行,但指令中说明不要限制,所以其实这个参数仅作为提示的一部分。
  • 输出:
    得到一个参数已经固定的 SQL 查询链 bound_chain,后续调用时只需要传入用户问题(以及其他动态数据)。


7. 调整输入数据格式

table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain
  • 作用:
    这行代码先用一个 lambda 函数将输入字典中的 "question" 键复制一份到 "input" 键,目的是统一变量名称(因为上面的 SQL 提示模板要求有 input 变量)。然后再将结果传递给 table_chain

  • 举例:

    • 输入: {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
    • lambda 输出: {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!", "input": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
    • 最终经过 table_chain 输出: 列表形式的 SQL 表名,如 ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]

8. 组合整个链条,生成最终 SQL 查询

full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain
  • 作用:
    这里使用 RunnablePassthrough.assign 将从 table_chain 得到的 SQL 表名列表赋值到上下文中的 table_names_to_use 键,然后通过管道传递给已经绑定参数的 SQL 查询链 bound_chain。这一步确保了在生成 SQL 查询时,上下文中不仅包含用户的原始问题,还包含了与之相关的 SQL 表名信息。

  • 输入:
    包含用户问题的字典(经过前面的处理已包含 "input" 键)。

  • 输出:
    经过整个链条处理后,输出最终生成的 SQL 查询语句。


9. 调用链条并生成 SQL 查询

query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
)
print(query)
  • 作用:
    这里将包含用户问题的字典传递给 full_chain。整个流程如下:

    1. 提取表类别:首先通过 table_chain"question" 转换为 "input",然后利用 LLM 提取出与问题相关的类别(预期为 “Music”)。
    2. 映射表名称:根据类别映射出所有与音乐相关的 SQL 表名。
    3. 生成 SQL 查询:利用绑定好的 bound_chain(包含 SQL 模板、数据库 schema 信息等),结合用户的问题和上下文信息,生成一个正确的 SQL 查询。
  • 输出举例:
    假设 LLM 理解问题并生成的 SQL 查询可能为:

    SELECT DISTINCT Genre.Name
    FROM Track
    JOIN Genre ON Track.GenreId = Genre.GenreId
    JOIN Artist ON Track.ArtistId = Artist.ArtistId
    WHERE Artist.Name = 'Alanis Morisette';
    

    (实际生成的 SQL 语句会依赖于 LLM 的理解和数据库的 schema 信息。)


最后运行这个SQL语句

db.run(query)

输出:
在这里插入图片描述

总结

这段代码整体实现了一个智能化的数据查询过程:

  • 输入: 用户问题(如关于 Alanis Morisette 歌曲的查询)。
  • 内部处理:
    1. 利用 LLM 提取相关 SQL 表类别。
    2. 根据类别映射出具体的 SQL 表名称。
    3. 结合数据库的表结构和预定义的 SQL 提示模板,生成正确的 SQL 查询语句。
  • 输出: 一条 SQL 查询语句,用来从数据库中获取答案。

这种链式结构使得整个流程模块化、可扩展:可以分别替换提取逻辑、映射逻辑和 SQL 查询生成逻辑,非常适合在实际应用中自动生成数据库查询。

相关文章:

  • 360个人版和企业版的区别
  • 在C++中如何实现线程安全的队列
  • Qt:窗口
  • CAN总线通信协议学习2——数据链路层之帧格式
  • 【Linux】TCP协议
  • 名词解释:vllm,大模型量化;以及如何在vllm实现大模型量化
  • Vue 系列之:基础知识
  • Java-servlet(二)Java-servlet-Web环境搭建(上)IDEA,maven和tomcat工具下载(附Gitee直接下载)
  • 现今大语言模型性能(准确率)比较
  • 《论企业集成架构设计及应用》审题技巧 - 系统架构设计师
  • 在Ubuntu 22.04 LTS 上安装 MySQL两种方式:在线方式和离线方式
  • 基于Java的AI应用开发实战:从模型训练到服务部署
  • 中间件专栏之Redis篇——Redis的基本IO网络模型
  • 每日OJ_牛客_NC316体育课测验(二)_拓扑排序_C++_Java
  • Typora安装教程(附安装包)Typora下载
  • 小结:BGP 的自动聚合与手动聚合
  • ENSP配置AAA验证
  • 鸿蒙日期格式工具封装及使用
  • Hadoop第一课(配置linux系统)
  • 【软考-架构】1.3、磁盘-输入输出技术-总线
  • 上影节官方海报公布:电影之城,每一帧都是生活
  • “十五五”规划编制工作开展网络征求意见活动
  • 周慧芳任上海交通大学医学院附属上海儿童医学中心党委书记
  • 事关中国,“英伟达正游说美国政府”
  • 央媒:设施老化、应急预案套模板,养老机构消防隐患亟待排查
  • 以军在加沙北部和南部展开大规模地面行动