LangChain-结合GLM+SQL+函数调用实现数据库查询(一)
业务流程
实现步骤
1. 加载数据库配置
在项目的根目录下创建.env 文件,设置文件内容:
DB_HOST=xxx
DB_PORT=3306
DB_USER=xxx
DB_PASSWORD=xxx
DB_NAME=xxx
DB_CHARSET=utf8mb4
加载环境变量,从 .env 文件中读取数据库配置信息
使用 os.getenv() 从环境变量中获取数据库的主机地址、端口、用户名、密码、数据库名和字符集
配置数据库连接参数
使用 quote 对密码进行 URL 编码,确保密码中的特殊字符不会导致连接失败
import os
from urllib.parse import quote
from dotenv import load_dotenvload_dotenv()
B_CONFIG = {"host": os.getenv('DB_HOST'), "port": int(os.getenv('DB_PORT')), "user": os.getenv('DB_USER'),"password": os.getenv('DB_PASSWORD'),"database": os.getenv('DB_NAME'),"charset": os.getenv('DB_CHARSET')
}# 处理特殊字符密码
encoded_password = quote(DB_CONFIG['password'])
构建 MySQL 数据库连接 URI,并连接数据库
构建 SQLAlchemy 的连接 URI,使用 pymysql 作为驱动程序。
设置连接超时时间为 10 秒
创建一个 SQLDatabase 实例,用于与 MySQL 数据库交互
MYSQL_URI = (f"mysql+pymysql://{DB_CONFIG['user']}:{encoded_password}@"f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/"f"{DB_CONFIG['database']}?"f"charset={DB_CONFIG['charset']}&connect_timeout=10"
)db = SQLDatabase.from_uri(MYSQL_URI)
2.初始化大语言模型
初始化一个基于 ChatOpenAI 的模型,使用智谱 AI 的 GLM-4 模型。
配置 API 密钥和基础 URL
llm = ChatOpenAI(temperature=1,model='glm-4-0520',api_key='*****',base_url='https://open.bigmodel.cn/api/paas/v4/'
)
3.定义提示模板
提示模板指导 LLM 根据给定的表结构和用户问题生成 SQL 查询语句
custom_prompt = PromptTemplate.from_template("""
你是一个专业的SQL工程师,请根据以下表结构生成标准SQL查询语句:{table_info}请最多返回 {top_k} 条记录。问题:{input}
SQL查询:
""")
4.SQL 查询链的创建和调用
定义表结构 table_info 和最大返回记录数 top_k。
调用 invoke 方法生成 SQL 查询语句
chian = create_sql_query_chain(llm=llm,db=db,prompt=custom_prompt
)# chian.get_prompts()[0].pretty_print()
# 表结构信息和 top_k 的值
table_info = "这里是表结构信息,例如:member(id, name, tenant_code, deleted)"
top_k = 3
resp = chian.invoke({"input": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?","question": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?",'table_info': table_info,'top_k': top_k
})
5.输出打印
执行生成的 SQL 查询。
使用 ast.literal_eval 安全地解析结果。
输出最终的查询结果。
print('大语言模型生成的SQL:' + resp)
sql = resp.replace('```sql', '').replace('```', '')
print('提取之后的SQL:' + sql)try:result = db.run(sql)# 清洗结果result_list = ast.literal_eval(result)total_count = result_list[0][0]print(f"最终的查询结果为:{total_count}")
except Exception as e:print(f"❌ SQL 执行失败: {str(e)}")
输出结果:
完整代码:
import ast
from langchain.chains.sql_database.query import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import os
from urllib.parse import quote
from dotenv import load_dotenvload_dotenv()
# 基础配置(建议通过环境变量获取)
DB_CONFIG = {"host": os.getenv('DB_HOST'), # 移除默认值"port": int(os.getenv('DB_PORT')), # 必须转换为整数"user": os.getenv('DB_USER'),"password": os.getenv('DB_PASSWORD'),"database": os.getenv('DB_NAME'),"charset": os.getenv('DB_CHARSET') # 动态获取字符集
}# 处理特殊字符密码
encoded_password = quote(DB_CONFIG['password'])# SQLAlchemy连接URI
MYSQL_URI = (f"mysql+pymysql://{DB_CONFIG['user']}:{encoded_password}@"f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/"f"{DB_CONFIG['database']}?"f"charset={DB_CONFIG['charset']}&connect_timeout=10"
)
# 创建模型
llm = ChatOpenAI(temperature=1,model='glm-4-0520',api_key='****',base_url='https://open.bigmodel.cn/api/paas/v4/'
)db = SQLDatabase.from_uri(MYSQL_URI)
# print(db.dialect)
# print(db.get_usable_table_names())
# print(db.run("SELECT COUNT(1) FROM member where saas_tenant_code ='linefriends' and deleted=0;"))# 自定义提示模板
custom_prompt = PromptTemplate.from_template("""
你是一个专业的SQL工程师,请根据以下表结构生成标准SQL查询语句:{table_info}请最多返回 {top_k} 条记录。问题:{input}
SQL查询:
""")chian = create_sql_query_chain(llm=llm,db=db,prompt=custom_prompt
)# chian.get_prompts()[0].pretty_print()
# 表结构信息和 top_k 的值
table_info = "这里是表结构信息,例如:member(id, name, tenant_code, deleted)"
top_k = 3
resp = chian.invoke({"input": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?","question": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?",'table_info': table_info,'top_k': top_k
})print('大语言模型生成的SQL:' + resp)
sql = resp.replace('```sql', '').replace('```', '')
print('提取之后的SQL:' + sql)try:result = db.run(sql)# 清洗结果result_list = ast.literal_eval(result)total_count = result_list[0][0]print(f"最终的查询结果为:{total_count}")
except Exception as e:print(f"❌ SQL 执行失败: {str(e)}")