MAC-SQL:SQL-Llama 的具体训练流程
要理解SQL-Llama的训练任务数量和具体训练方式,需结合论文中对Agent-Instruct数据集设计和多任务监督微调流程的描述,以下是分点解析:
一、SQL-Llama的“3个指令任务”是什么?
论文中提到的 N=3 并非指“3条独立指令”,而是指3类核心Agent任务——这3类任务对应MAC-SQL框架中3个Agent的核心功能,构成了SQL-Llama的微调任务体系。这3类任务完全匹配框架需求,确保SQL-Llama能替代GPT-4完成所有Agent工作,具体如下:
| 任务类别(N=3) | 对应Agent | 任务目标 | 示例场景 |
|---|---|---|---|
| 1. 数据库简化任务 | Selector | 从完整数据库中筛选出与用户问题相关的最小子数据库(含相关表、列),避免无关信息干扰 | 当用户问“特许学校的SAT优秀率”时,筛选出frpm(含特许学校标识)、satscores(含SAT成绩)表,剔除无关的schools表中“街道、电话”等列 |
| 2. 问题分解与SQL生成任务 | Decomposer | 将复杂问题拆解为子问题,通过思维链(CoT)生成每个子问题的SQL,最终整合为完整SQL | 用户问“SAT优秀率高于平均的特许学校名称”,拆解为: 子问题1:计算特许学校的SAT优秀率平均值(生成对应SQL) 子问题2:筛选出优秀率高于该平均值的学校名称(生成最终SQL) |
| 3. SQL错误修正任务 | Refiner | 接收外部工具(如SQLite)反馈的错误信息(如语法错误、表名错误),修正错误SQL | 若生成的SQL因“冗余括号”报错(near "(": syntax error),自动删除多余括号,确保SQL可执行且逻辑不变 |
二、SQL-Llama的具体训练流程:从数据构建到微调执行
SQL-Llama基于Code Llama 7B基座模型,通过“构建Agent-Instruct数据集→多任务监督微调”两步实现,核心是让模型学习“如何完成上述3类Agent任务”,具体过程可拆解为4个关键步骤:
步骤1:构建Agent-Instruct训练数据集(核心输入)
为了让模型学习3类任务,论文首先构建了包含这3类任务的Agent-Instruct数据集(共10,000条高质量数据),数据来源和筛选逻辑如下:
- 数据来源:基于BIRD和Spider两个Text-to-SQL基准数据集的训练集,由GPT-4模拟MAC-SQL框架的3个Agent工作流程,生成对应任务的“指令-输出”样本。
- 例如,针对“数据库简化任务”,生成“用户问题+完整数据库schema→筛选后的子数据库schema”的样本;
- 针对“SQL修正任务”,生成“错误SQL+错误信息→正确SQL”的样本。
- 数据筛选:过滤掉SQL输出错误、任务逻辑不匹配的样本,确保数据质量,最终得到覆盖3类任务、匹配BIRD/Spider数据分布的数据集
D = {D₁, D₂, D₃}(D₁对应任务1,D₂对应任务2,D₃对应任务3)。
步骤2:定义多任务微调目标(损失函数)
训练的核心是让SQL-Llama同时学习3类任务,论文采用多任务联合损失函数,让模型在3类任务样本上共同优化,公式如下(对应论文公式):
L=−∑i=13EQ,Si,K,Yi∼Di[logP(Yi∣Q,Si,K;M)]\mathcal{L}=-\sum_{i=1}^{3} \mathbb{E}_{\mathcal{Q}, S^{i}, \mathcal{K}, \mathcal{Y}^{i} \sim \mathcal{D}_i}\left[log P\left(\mathcal{Y}^{i} | \mathcal{Q}, \mathcal{S}^{i}, \mathcal{K} ; \mathcal{M}\right)\right]L=−i=1∑3EQ,Si,K,Yi∼Di[logP(Yi∣Q,Si,K;M)]
- 符号含义:
i=1,2,3:对应3类Agent任务;Q:用户自然语言问题;Sⁱ:任务i所需的数据库schema(如任务1的“完整schema”、任务2的“筛选后schema”);K:外部知识(如“SAT优秀率=NumGE1500/NumTstTakr”);Yⁱ:任务i的目标输出(如任务1的“子数据库schema”、任务2的“最终SQL”);log P(...):模型预测输出Yⁱ的对数概率,损失函数通过最小化“负对数概率”,让模型更大概率生成正确的任务输出。
步骤3:执行监督微调(基于Code Llama 7B)
以Code Llama 7B为基座模型,采用全参数监督微调(所有模型参数均更新),训练过程遵循以下逻辑:
- 输入格式统一:将3类任务的样本统一为“指令模板+输入信息”的格式,确保模型能识别任务类型。
- 示例(任务2:问题分解与SQL生成):
输入模板:[数据库schema: {S²}] [用户问题: {Q}] [外部知识: {K}] 请拆解问题并生成最终SQL:
目标输出:{Y²}(子问题列表+最终SQL)。
- 示例(任务2:问题分解与SQL生成):
- 训练过程:将Agent-Instruct数据集的3类样本混合输入模型,按批次更新模型参数,优化上述多任务损失函数。训练中需平衡模型复杂度(7B参数)与性能,避免过拟合(通过数据多样化、梯度裁剪等策略)。
步骤4:训练效果验证
微调完成后,SQL-Llama需验证是否能完成3类Agent任务:
- 替代Selector:能根据问题筛选出最小子数据库;
- 替代Decomposer:能拆解复杂问题并生成正确SQL;
- 替代Refiner:能修正错误SQL。
实验结果显示,SQL-Llama在BIRD数据集的开发集上达到43.94%的执行准确率,接近GPT-4的46.35%,证明其能有效学习3类任务并适配MAC-SQL框架。
三、举个具体例子:SQL-Llama如何学习“SQL修正任务”?
假设训练集中有一条“SQL修正任务”的样本,具体如下:
- 任务输入(Q+Sⁱ+K+错误信息):
- 用户问题:List school names of charter schools with an SAT excellence rate over the average.
- 数据库schema(S³):
frpm表(含CDSCode、Charter School(Y/N))、satscores表(含cds、sname、NumGE1500、NumTstTakr); - 外部知识(K):SAT_Excellence_Rate = NumGE1500 / NumTstTakr;
- 错误SQL:
SELECT T2.sname FROM frpm AS T1 JOIN satscores AS T2 ON T1.CDSCode = T2.cds WHERE T1.Charter School(Y/N) = 1 AND (NumGE1500/NumTstTakr) > ((SELECT AVG(NumGE1500/NumTstTakr) FROM frpm AS T1 JOIN satscores AS T2 ON T1.CDSCode = T2.cds WHERE T1.Charter School(Y/N) = 1)); - 错误信息:
SQLite error: near "(": syntax error(冗余括号导致)。
- 目标输出(Y³):
修正后的SQL:SELECT T2.sname FROM frpm AS T1 JOIN satscores AS T2 ON T1.CDSCode = T2.cds WHERE T1.Charter School(Y/N) = 1 AND (NumGE1500/NumTstTakr) > (SELECT AVG(NumGE1500/NumTstTakr) FROM frpm AS T1 JOIN satscores AS T2 ON T1.CDSCode = T2.cds WHERE T1.Charter School(Y/N) = 1)(删除多余的内层括号)。
训练时,模型会学习“错误SQL+错误信息→正确SQL”的映射关系,当后续遇到类似括号冗余错误时,能自动修正。
总结
- SQL-Llama的“3个指令任务”是3类Agent功能任务(数据库简化、问题分解与SQL生成、SQL修正),而非3条独立指令;
- 训练核心是“构建3类任务的高质量数据集→通过多任务联合损失函数微调Code Llama 7B”,让模型学会完成MAC-SQL框架的所有Agent工作;
- 最终实现“轻量级模型(7B参数)接近GPT-4性能”的目标,为MAC-SQL框架提供开源、低成本的模型选择。
