【论文解读】Toolformer: 语言模型自学使用工具
1st author: Timo Schick - Google Scholar
paper: Toolformer: Language Models Can Teach Themselves to Use Tools | OpenReview NeurIPS 2023 oral
code: lucidrains/toolformer-pytorch: Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI
5. 总结 (结果先行)
Toolformer 提出了一种自监督框架,使语言模型能够学会自主使用外部工具。通过以“降低未来词元预测损失”为目标来过滤和学习 API 调用,Toolformer 在不牺牲核心语言能力的前提下,显著增强了模型在知识获取、数学计算、实时信息处理等方面的零样本能力,甚至能让较小模型匹敌数倍于其参数量的更大模型。
前瞻与思考:
- 工具链的挑战: 目前的 Toolformer 独立采样每个 API 调用,尚不支持将一个工具的输出作为另一个工具的输入(工具链)。这是未来一个重要的扩展方向。
- 交互式工具使用: 对于搜索引擎这类可能返回大量结果的工具,允许模型进行多轮交互(如 уточнение запроса)将是关键。
- API 调用成本: 当前模型未考虑调用不同 API 的计算或经济成本。
- 鲁棒性与泛化: 模型对提示词的敏感性以及如何更好地泛化到未见过的工具或组合值得进一步研究。
- 更深层次的“理解”: Toolformer 通过降低预测损失来学习,这是否等同于对工具功能和适用场景的深层理解,仍是一个开放问题。
1. 思想
语言模型(LMs)在文本生成和理解上已展现惊人能力,但其固有缺陷——如知识局限于训练数据、易产生事实幻觉、数学推理薄弱、缺乏对动态世界的实时感知——限制了其在更复杂和真实场景中的应用。
Toolformer 的核心思想是:让语言模型通过自监督学习,自主学会调用外部工具(APIs)来弥补自身能力的不足 ——既保留 LMs 的强大语言能力,又获得工具的精确性和实时性。 让 LM 自己决定何时(when)、为何(what arguments)、以及如何(how to incorporate results)使用何种工具。
其本质是探索一种让 LMs 从“封闭世界”走向“开放世界”,与外部环境交互并从中学习的有效路径。
2. 方法
Toolformer 的方法论体现了巧妙的自监督设计,主要包含以下 3 个步骤,针对每个工具独立进行:
-
API 调用采样 (Sampling API Calls):
- 目的: 对于数据集 C \mathcal{C} C 中的每个文本 x x x,识别潜在有益的 API 调用位置和内容。
- 过程:
- 给定一个纯文本数据集 C = { x ( 1 ) , x ( 2 ) , … , x ( N ) } \mathcal{C} = \{x^{(1)}, x^{(2)}, \dots, x^{(N)}\} C={x(1),x(2),…,x(N)},其中每个 x x x 是一个文本序列,例如 x = ( x 1 , x 2 , … , x n ) x = (x_1, x_2, \dots, x_n) x=(x1,x2,…,xn), x j x_j xj 代表第 j j j 个词元 (token)。和一些特定的 API(例如,计算器API、问答API)。
- 通过少量人工编写的 few-shot 示例,构建一个提示 (prompt),引导一个预训练的 LM 在文本的潜在位置 i i i 生成候选的 API 调用。例如,在文本 “… 1400名参与者中,400人…” 处,LM 可能会被引导生成
[Calculator(400/1400)]
这样的调用。 - LM 会为文本中每个位置 i i i 生成一个代表特殊 token (
<API>
) 的概率 P M ( <API> ∣ P ( x ) , x 1 : i − 1 ) P_M(\text{<API>} | P(x), x_{1:i-1}) PM(<API>∣P(x),x1:i−1),并以超过某个阈值的位置 i i i 多作为候选的 API 调用起始位置。 - 在每个选定的位置 i i i,让模型 M M M 以 P api ( x ) , x 1 : i − 1 , <API> P_{\text{api}}(x), x_{1:i-1}, \text{<API>} Papi(x),x1:i−1,<API> 为前缀,继续生成 API 调用的具体内容(即 API 名称和输入参数),直到生成 API 调用的结束标记 </API> \text{</API>} </API> (如论文中实际使用的
]
)。这样,我们就得到了一系列候选的 API 调用 c 1 , c 2 , … , c m c_1, c_2, \dots, c_m c1,c2,…,cm。每个调用 c c c 可以看作是一个包含 API 名称 a c a_c ac 和输入 i c i_c ic 的元组 ( a c , i c ) (a_c, i_c) (ac,ic)。
-
API 调用执行 (Executing API Calls):
- 目的: 执行上一步采样的所有 API 调用 c j c_j cj,获得相应的文本结果 r j r_j rj。
- 过程:
- 对每个采样到的 API 调用 c j = ( a c , i c ) c_j = (a_c, i_c) cj=(ac,ic),我们执行它。例如,如果 c j c_j cj 是一个计算器调用
Calculator("400/1400")
,我们就实际计算 400 / 1400 400/1400 400/1400。 - 执行后得到文本形式的返回结果 r j r_j rj。例如,对于
Calculator("400/1400")
,结果 r j r_j rj 可能是 “0.29”。
- 对每个采样到的 API 调用 c j = ( a c , i c ) c_j = (a_c, i_c) cj=(ac,ic),我们执行它。例如,如果 c j c_j cj 是一个计算器调用
-
API 调用过滤 (Filtering API Calls):
-
目的: 只保留那些能够显著帮助模型 M M M 预测原始文本 x x x 中后续词元的 API 调用。(这是自监督学习的核心)
-
过程:
-
我们将一个 API 调用 c c c(包含其输入 i c i_c ic)及其返回结果 r r r 线性化为文本序列。我们定义两种形式:
- e ( c ) = <API> a c ( i c ) </API> e(c) = \text{<API>} a_c(i_c) \text{</API>} e(c)=<API>ac(ic)</API>:仅包含 API 调用本身。
- e ( c , r ) = <API> a c ( i c ) → r </API> e(c,r) = \text{<API>} a_c(i_c) \rightarrow r \text{</API>} e(c,r)=<API>ac(ic)→r</API>:包含 API 调用及其结果 r r r。“ → \rightarrow →”是一个特殊的分隔符。
-
对于在文本 x x x 的位置 i i i 插入的 API 调用 c c c 及其结果 r r r,我们计算模型 M M M 在给定前缀 z z z 的情况下,对后续文本 x i : n = ( x i , x i + 1 , … , x n ) x_{i:n} = (x_i, x_{i+1}, \dots, x_n) xi:n=(xi,xi+1,…,xn) 的加权交叉熵损失 L i ( z ) L_i(z) Li(z):
L i ( z ) = − ∑ j = i n w j − i log P M ( x j ∣ z , x 1 : j − 1 ) L_i(z) = -\sum_{j=i}^{n} w_{j-i} \log P_M(x_j | z, x_{1:j-1}) Li(z)=−j=i∑nwj−ilogPM(xj∣z,x1:j−1)- 符号解释:
- z z z: 代表插入到 x 1 : i − 1 x_{1:i-1} x1:i−1 和 x i : n x_{i:n} xi:n 之间的前缀序列。它可以是 e ( c , r ) e(c,r) e(c,r)(调用+结果), e ( c , ϵ ) e(c, \epsilon) e(c,ϵ)(调用+空结果),或者 ϵ \epsilon ϵ(完全没有调用)。
- ϵ \epsilon ϵ: 代表空序列。
- w j − i w_{j-i} wj−i: 是一个权重,通常随着 j − i j-i j−i (即词元 x j x_j xj 与 API 调用位置 i i i 的距离) 的增大而减小。这意味着模型更关注紧随 API 调用之后的词元的预测。
- 符号解释:
-
我们比较以下两种损失:
- L i + = L i ( e ( c , r ) ) L_i^+ = L_i(e(c,r)) Li+=Li(e(c,r)): 在位置 i i i 提供了完整的 API 调用及其结果 r r r 时的损失。
- L i − = min ( L i ( ϵ ) , L i ( e ( c , ϵ ) ) ) L_i^- = \min(L_i(\epsilon), L_i(e(c, \epsilon))) Li−=min(Li(ϵ),Li(e(c,ϵ))): 不进行任何 API 调用 (损失为 L i ( ϵ ) L_i(\epsilon) Li(ϵ)),或者进行了 API 调用但没有提供结果 (损失为 L i ( e ( c , ϵ ) ) L_i(e(c, \epsilon)) Li(e(c,ϵ))) 时的最小损失。
-
只有当提供完整的 API 调用和结果能显著降低损失时,即 L i − − L i + ≥ τ f L_i^- - L_i^+ \ge \tau_f Li−−Li+≥τf (其中 τ f \tau_f τf 是一个预设的过滤阈值),我们才认为这个 API 调用 c c c 在位置 i i i 是有用的,并将其保留。
-
-
模型微调 (Model Finetuning):
- 将所有通过过滤的、有用的 API 调用及其结果整合回原始文本,形成一个增强的数据集 C ∗ \mathcal{C}^* C∗。
- 在 C ∗ \mathcal{C}^* C∗ 上微调原始的 LM。
推理阶段: 当微调后的 Toolformer 生成文本时,如果它预测出特殊序列 “ → \rightarrow →”,表示它期望一个 API 的返回结果。此时,解码过程暂停,系统执行相应的 API 调用,并将返回结果插入文本流,然后继续解码。
3. 优势
Toolformer 的设计带来了几个显著优势:
- 自监督学习: 无需大规模人工标注 API 调用数据,极大降低了数据获取成本,并使得模型能从其认为“有用”的角度学习,而非人类预设的“有用”。
- 通用性与自主性: 模型自主决定何时、何地、如何使用何种工具,而非针对特定任务硬编码工具使用逻辑。这使得 Toolformer 具备更广泛的适用性。
- 性能提升: 在多种下游任务上,尤其是在需要事实查找、计算或最新信息的场景中,Toolformer(基于 GPT-J 6.7B)的零样本性能显著优于同等规模甚至远大于其的 LMs(如 GPT-3 175B 的部分任务)。
- 保持核心语言能力: 由于微调数据是在原始文本基础上增加 API 调用,模型在学习使用工具的同时,其核心的语言建模能力(如 PPL)并未受到损害。
- 工具多样性: 该框架可集成多种类型的工具,如问答系统、计算器、搜索引擎、翻译系统和日历。
4. 实验
论文进行了一系列详尽的实验,以验证 Toolformer 的有效性:
- 基础模型: 主要使用 GPT-J (6.7B 参数) 作为基础 LM。
- 数据集: 使用 CCNet 的子集进行 API 调用采样和微调。
- 评估任务:
- LAMA (事实知识探测): Toolformer (使用 QA 工具) 在 SQuAD, Google-RE, T-REx 子集上显著优于 GPT-J,并与 GPT-3 (175B) 表现相当或更好。
- 数学推理 (ASDiv, SVAMP, MAWPS): Toolformer (使用计算器) 性能远超所有基线,包括 GPT-3 (175B)。
- 问答 (WebQS, NQ, TriviaQA): Toolformer (使用维基百科搜索) 表现优于同尺寸模型,但不及 GPT-3,可能受限于搜索工具的简单性和交互性。
- 多语言问答 (MLQA): Toolformer (使用翻译工具) 在所有语言上均有提升,但由于 CCNet 微调可能引入分布漂移,其表现并非总是优于原始 GPT-J。
- 时序数据集 (TEMPLAMA, DATESET): Toolformer (主要使用维基搜索和QA工具处理 TEMPLAMA,使用日历工具处理 DATESET) 均优于基线。
- 语言模型能力: 在 WikiText 和 CCNet 验证集上评估困惑度 (PPL),Toolformer (禁用 API 调用时) 的 PPL 与在纯 CCNet 上微调的 GPT-J 相当,表明其语言能力未受损。
- 模型规模扩展性 (Scaling Laws): 实验表明,模型利用工具的能力大约在 775M 参数规模时开始显现,并随模型增大而增强。