NL2SQL模型应用实践-解决上百张表筛选问题
用一个最简单的例子说明NL2SQL模型如何使用,假设我们只有3张表,并通过分步演示整个流程。
数据库表结构(3张表)
-
学生表(students)
CREATE TABLE students (student_id INT PRIMARY KEY,name VARCHAR(50),age INT,class_id INT -- 外键关联classes表 );
student_id name age class_id 1 张三 18 101 2 李四 19 102 -
课程表(courses)
CREATE TABLE courses (course_id INT PRIMARY KEY,course_name VARCHAR(50),teacher VARCHAR(50) );
course_id course_name teacher 201 数学 王老师 202 英语 李老师 -
选课表(enrollments)
CREATE TABLE enrollments (enrollment_id INT PRIMARY KEY,student_id INT, -- 外键关联students表course_id INT -- 外键关联courses表 );
enrollment_id student_id course_id 301 1 201 302 2 202
步骤1:用户输入自然语言问题
问题:
“找出选了数学课的所有学生姓名”
步骤2:自动选择相关表
通过简单的关键词匹配(实际应用可用更复杂的NLP方法):
- “数学课” → 匹配
courses
表的course_name
字段 - “学生姓名” → 匹配
students
表的name
字段 - 需要关联表 →
enrollments
表包含学生和课程的关联
最终选择表:
students
, courses
, enrollments
步骤3:生成表结构描述(Prompt)
将表结构转换为NL2SQL模型能理解的文本描述:
数据库表结构:
1. 表名: students字段:- student_id (整数): 学生ID- name (字符串): 学生姓名- age (整数): 学生年龄- class_id (整数): 班级ID2. 表名: courses字段:- course_id (整数): 课程ID- course_name (字符串): 课程名称- teacher (字符串): 授课老师3. 表名: enrollments字段:- enrollment_id (整数): 选课记录ID- student_id (整数): 关联students表- course_id (整数): 关联courses表表关联关系:- enrollments.student_id → students.student_id- enrollments.course_id → courses.course_id
步骤4:NL2SQL模型输入(示例)
输入 = """
问题:找出选了数学课的所有学生姓名数据库表结构:
[上述表结构描述]
"""
步骤5:模型输出SQL
SELECT s.name
FROM students s
JOIN enrollments e ON s.student_id = e.student_id
JOIN courses c ON e.course_id = c.course_id
WHERE c.course_name = '数学';
步骤6:执行SQL得到结果
name |
---|
张三 |
关键点总结
- 表选择:通过问题中的关键词(“学生”、“数学课”)锁定
students
、courses
和关联表enrollments
- Prompt构建:清晰描述表字段和关联关系
- 模型输出:生成正确的三表JOIN操作和过滤条件
对比:如果没有表选择步骤
假设将所有表结构都输入模型:
数据库表结构:
[包含数百张表的完整描述...]
会导致:
- 模型输入过长
- 生成错误SQL的概率增加
- 计算资源浪费
当面对数百张表时,高效筛选与用户问题相关的表是NL2SQL应用的关键。
一、核心问题分析
NL2SQL 模型的输入通常包括:
- 自然语言问题:用户的查询语句(如 “查询 2023 年上海地区的销售额”)。
- 数据库模式(Schema):表名、字段名、字段类型、表间关系等。
当表数量庞大时,直接输入全量表结构会导致:
- 模型过载:输入维度爆炸,影响生成 SQL 的准确性。
- 效率低下:解析全量数据耗时,无法满足实时查询需求。
核心目标:通过动态筛选相关表,仅向模型提供与问题相关的少量表结构。
二、整体思路
三、表筛选策略
1. 元数据匹配(轻量级)
建立表/列名的语义索引:
from sklearn.feature_extraction.text import TfidfVectorizer# 元数据示例
table_metadata = {"sales_records": "存储客户交易数据,包括订单号、金额、日期","customer_info": "客户基本信息表,含姓名、电话、地址","product_catalog": "商品信息表,含商品ID、名称、类别"
}# 创建元数据索引
vectorizer = TfidfVectorizer()
corpus = list(table_metadata.values())
tfidf_matrix = vectorizer.fit_transform(corpus)def find_relevant_tables(query, top_k=3):query_vec = vectorizer.transform([query])cosine_sim = (query_vec * tfidf_matrix.T).toarray()[0]sorted_indices = cosine_sim.argsort()[::-1][:top_k]return [list(table_metadata.keys())[i] for i in sorted_indices]# 使用示例
query = "找出上海客户的最近订单金额"
tables = find_relevant_tables(query) # 返回 ['sales_records', 'customer_info']
2. 向量检索(高精度)
使用Embedding模型增强语义理解:
from sentence_transformers import SentenceTransformer
import numpy as np# 加载预训练模型
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')# 生成表描述向量
table_descriptions = list(table_metadata.values())
table_embeddings = model.encode(table_descriptions)def vector_based_search(query, top_k=3):query_embedding = model.encode([query])similarity = np.dot(query_embedding, table_embeddings.T)[0]top_indices = similarity.argsort()[::-1][:top_k]return [list(table_metadata.keys())[i] for i in top_indices]# 使用示例
query = "统计电子产品季度销售额"
tables = vector_based_search(query) # 返回 ['sales_records', 'product_catalog']
3. 图关系分析(处理复杂查询)
构建表关系图谱实现关联表发现:
import networkx as nx# 构建表关系图
schema_graph = nx.Graph()
schema_graph.add_edges_from([("sales_records", "customer_info", {"key": "customer_id"}),("sales_records", "product_catalog", {"key": "product_id"}),("customer_info", "region_data", {"key": "region_id"})
])def expand_related_tables(seed_tables, depth=1):related = set(seed_tables)for table in seed_tables:neighbors = nx.neighbors(schema_graph, table)related.update(neighbors)return list(related)# 使用示例
seed_tables = ["customer_info"] # 初步筛选的表
full_set = expand_related_tables(seed_tables) # 返回 ['customer_info', 'sales_records', 'region_data']
四、NL2SQL模型集成方案
1. 完整工作流
def nl2sql_with_table_selection(user_query, all_tables):# 步骤1: 初步表筛选seed_tables = vector_based_search(user_query, top_k=2)# 步骤2: 关系扩展selected_tables = expand_related_tables(seed_tables)# 步骤3: 获取表结构table_schemas = get_table_schemas(selected_tables)# 步骤4: NL2SQL推理sql = nl2sql_model.generate(question=user_query,schema=table_schemas)return sql# 示例调用
user_query = "计算上海客户购买手机的平均金额"
sql = nl2sql_with_table_selection(user_query, all_tables)
2. 表结构描述优化
生成Schema描述:
def get_table_schemas(table_names):schemas = []for name in table_names:# 获取列信息columns = database.get_columns(name)# 构建描述schema_desc = f"表名: {name}\n描述: {table_metadata[name]}\n字段:"for col in columns:schema_desc += f"\n - {col['name']} ({col['type']}): {col['comment']}"# 添加主外键关系fks = database.get_foreign_keys(name)if fks:schema_desc += "\n关联关系:"for fk in fks:schema_desc += f"\n - {fk['from']} → {fk['to_table']}.{fk['to_column']}"schemas.append(schema_desc)return "\n\n".join(schemas)
五、解决方案架构
组件说明:
- 元数据索引:Elasticsearch存储表/列描述
- 向量数据库:FAISS/Pinecone存储表描述向量
- 图数据库:Neo4j管理表关系
- 表筛选服务:综合多种策略输出精简表集合
- NL2SQL模型:基于T5/Codex的微调模型
六、实践案例
问题:
“查询2023年Q3华东地区销售额Top 10的电子产品及其供应商”
表筛选过程:
- 向量检索识别核心表:
sales_records
,product_catalog
- 图关系扩展:
sales_records
→region_data
(地区信息)product_catalog
→supplier_info
(供应商)
- 最终表集合(4张):
sales_records
,product_catalog
,region_data
,supplier_info
NL2SQL输入:
问题:查询2023年Q3华东地区销售额Top 10的电子产品及其供应商表结构:
1. 表名: sales_records描述: 销售记录表字段:- order_id (int): 订单ID- product_id (int): 商品ID- region_id (int): 地区ID- sale_amount (float): 销售金额- sale_date (date): 销售日期2. 表名: product_catalog描述: 商品信息表字段:- product_id (int): 商品ID- product_name (varchar): 商品名称- category (varchar): 商品类别- supplier_id (int): 供应商ID3. 表名: region_data描述: 地区信息表字段:- region_id (int): 地区ID- region_name (varchar): 地区名称4. 表名: supplier_info描述: 供应商信息表字段:- supplier_id (int): 供应商ID- supplier_name (varchar): 供应商名称关联关系:- sales_records.product_id → product_catalog.product_id- sales_records.region_id → region_data.region_id- product_catalog.supplier_id → supplier_info.supplier_id
生成SQL:
SELECT p.product_name,s.supplier_name,SUM(sr.sale_amount) AS total_sales
FROM sales_records sr
JOIN product_catalog p ON sr.product_id = p.product_id
JOIN region_data r ON sr.region_id = r.region_id
JOIN supplier_info s ON p.supplier_id = s.supplier_id
WHERE r.region_name = '华东'AND sr.sale_date BETWEEN '2023-07-01' AND '2023-09-30'AND p.category = '电子产品'
GROUP BY p.product_name, s.supplier_name
ORDER BY total_sales DESC
LIMIT 10;
通过组合元数据匹配、向量检索和图关系分析,可有效从数百张表中筛选出3-5张核心表,既满足NL2SQL模型的输入要求,又确保生成SQL的准确性。实际应用中结合缓存机制优化性能。