当前位置: 首页 > news >正文

【论文阅读】Thinkless: LLM Learns When to Think

Thinkless: LLM Learns When to Think

    • Method
      • Distillation for Warm-up
      • Learning When to Think via Decoupled GRPO
      • 奖励设计
      • 解耦策略优化
      • 方法总结

Thinkless: LLM Learns When to Think 这篇文章介绍了名为Thinkless的模型,该模型旨在通过自适应选择思考和非思考推理模式来提高推理语言模型(LLM)的效率和性能。Thinkless的核心方法是Decoupled Group Relative Policy Optimization (Decoupled GRPO),它将混合推理目标分解为两个组件:模式选择(Mode Selection)和准确性提升(Accuracy Improvement)。模式选择决定了模型何时采用长链推理,而准确性提升则在选定的推理模式下优化响应内容以提高答案的正确性。

项目地址:https://github.com/VainF/Thinkless

在这里插入图片描述
本篇博客仅聚焦文章的方法部分

Method

文章分两个阶段实现:

  • (1)Distillation for Warm-up,对预训练的推理模型进行微调以统一两种推理风格,
  • (2)Reinforcement Learning with DeGRPO

Distillation for Warm-up

在作者提出的框架中,第一步是构建一个能够生成短回应和长回应的模型 πθ\pi_\thetaπθ。作者利用两个预训练的专家模型进行蒸馏:一个推理模型 πthink\pi_{\text{think}}πthink,训练用于通过逐步推理生成详细的思维链;以及一个指令跟随模型 πshort\pi_{\text{short}}πshort,优化用于生成与用户意图对齐的简洁回答。给定一个提示语料库 X={xi}i=1NX = \{x_i\}_{i=1}^NX={xi}i=1N,我们使用这些模型生成一个合成的配对数据集:

Ddistill={(xi,<think>athinki,<short>ashorti)}i=1N,D_{\text{distill}} = \{ (x_i, <\text{think}> a_{\text{think}}^i, <\text{short}> a_{\text{short}}^i) \}_{i=1}^N, Ddistill={(xi,<think>athinki,<short>ashorti)}i=1N,

其中 athinki=πlong(xi)a_{\text{think}}^i = \pi_{\text{long}}(x_i)athinki=πlong(xi)ashorti=πshort(xi)a_{\text{short}}^i = \pi_{\text{short}}(x_i)ashorti=πshort(xi)。每个response都以前缀控制标记 c∈C={<short>,<think>}c \in C = \{<\text{short}>, <\text{think}>\}cC={<short>,<think>} 开头,以指示预期的推理风格。然后,我们通过监督细调(SFT)在该数据集上细调目标推理模型 πθ\pi_\thetaπθ。目标是学习一个基于控制标记的多风格回应分布。这一蒸馏阶段确保模型能够以高保真度生成两种类型的回应,同时配对数据集的构建确保模型的回应分布是平衡的,这有助于后续的强化学习过程探索不同的解决方案。

在这里插入图片描述

Learning When to Think via Decoupled GRPO

在蒸馏阶段之后,模型可以生成长回应和短回应。然而,它仍然缺乏一种机制来决定哪种推理模式适合特定的输入 xxx。为了提供这种能力,作者将模式选择视为一个强化学习问题,并优化策略 πθ(c,a∣x)=πθ(c∣x)πθ(a∣x,c)\pi_\theta(c, a | x) = \pi_\theta(c | x) \pi_\theta(a | x, c)πθ(c,ax)=πθ(cx)πθ(ax,c),其中第一个标记 c∈C={<short>,<think>}c \in C = \{<\text{short}>, <\text{think}>\}cC={<short>,<think>} 作为控制标记,决定了推理模式,而后续的标记 (ai,1,…,ai,Ti)(a_{i,1}, \ldots, a_{i,T_i})(ai,1,,ai,Ti) 构成了生成的回应。为了方便表示,我们将第 iii 个样本的整个长度为 Ti+1T_i + 1Ti+1 的序列表示为 ai=(ai,0,…,ai,Ti)a_i = (a_{i,0}, \ldots, a_{i,T_i})ai=(ai,0,,ai,Ti),其中 ai,0∈Ca_{i,0} \in Cai,0C 是控制标记。

奖励设计

y∗y^*y 表示输入 xxx 对应的正确答案。我们考虑一个最小设计的奖励函数 r(a,y∗,c)r(a, y^*, c)r(a,y,c),其奖励值如下:
r(a,y∗,c)={1.0,if c=<short>and Extract-Answer(a)=y∗,1.0−γ,if c=<think>and Extract-Answer(a)=y∗,−1.0,if Extract-Answer(a)≠y∗,r(a, y^*, c) = \begin{cases} 1.0, & \text{if } c = <\text{short}> \text{ and } \text{Extract-Answer}(a) = y^*, \\ 1.0 - \gamma, & \text{if } c = <\text{think}> \text{ and } \text{Extract-Answer}(a) = y^*, \\ -1.0, & \text{if } \text{Extract-Answer}(a) \neq y^*, \end{cases} r(a,y,c)=1.0,1.0γ,1.0,if c=<short> and Extract-Answer(a)=y,if c=<think> and Extract-Answer(a)=y,if Extract-Answer(a)=y,
其中 1>γ>01 > \gamma > 01>γ>0 引入了对短正确答案的偏好,相对于长回应。

解耦策略优化

基于简单的奖励函数,作者采用基于 GRPO 的框架进行训练。设 {ai}i=1G\{a_i\}_{i=1}^G{ai}i=1G 表示从当前策略 πθold\pi_{\theta_{\text{old}}}πθold 中采样的一个 mini-batch 轨迹。目标函数定义为:

JGRPO(θ)=Ex,ai[1G∑i=1G(1Ti+1∑t=0TiLi,t(θ)−βDKL[πθ(⋅∣x)∥πref(⋅∣x)])](1)J_{\text{GRPO}}(\theta) = \mathbb{E}_{x, a_i} \left[ \frac{1}{G} \sum_{i=1}^G \left( \frac{1}{T_i + 1} \sum_{t=0}^{T_i} L_{i,t}(\theta) - \beta D_{\text{KL}} \left[ \pi_\theta(\cdot | x) \parallel \pi_{\text{ref}}(\cdot | x) \right] \right) \right] \tag{1} JGRPO(θ)=Ex,ai[G1i=1G(Ti+11t=0TiLi,t(θ)βDKL[πθ(x)πref(x)])](1)

其中 Li,t(θ)L_{i,t}(\theta)Li,t(θ) 表示token-level surrogate loss,形式上定义为:

Li,t(θ)=min⁡(πθ(ai,t∣x,ai,<t)πθold(ai,t∣x,ai,<t)A^i,t,clip(πθ(ai,t∣x,ai,<t)πθold(ai,t∣x,ai,<t),1−ϵ,1+ϵ)A^i,t)(2)L_{i,t}(\theta) = \min \left( \frac{\pi_\theta(a_{i,t} | x, a_{i,<t})}{\pi_{\theta_{\text{old}}}(a_{i,t} | x, a_{i,<t})} \hat{A}_{i,t}, \text{clip} \left( \frac{\pi_\theta(a_{i,t} | x, a_{i,<t})}{\pi_{\theta_{\text{old}}}(a_{i,t} | x, a_{i,<t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) \tag{2} Li,t(θ)=min(πθold(ai,tx,ai,<t)πθ(ai,tx,ai,<t)A^i,t,clip(πθold(ai,tx,ai,<t)πθ(ai,tx,ai,<t),1ϵ,1+ϵ)A^i,t)(2)

在本工作中,作者使用 A^i,t=r−mean(r)\hat{A}_{i,t} = r - \text{mean}(r)A^i,t=rmean(r) 计算相对优势,这一选择是基于观察到训练数据包含不同难度的问题,使用标准差归一化可能会引入偏差。

当应用于训练时,方程(1)(1)(1)的目标函数有两个目的:

  • 学习适当的控制标记以进行模式选择,
  • 提高回应标记的准确性:
    1Ti+1∑t=0TiLi,t(θ)=1Ti+1Li,0(θ)(Control Tokens)+1Ti+1∑t=1TiLi,t(θ)(Response Tokens)(3)\frac{1}{T_i + 1} \sum_{t=0}^{T_i} L_{i,t}(\theta) = \frac{1}{T_i + 1} L_{i,0}(\theta) \quad \text{(Control Tokens)} + \frac{1}{T_i + 1} \sum_{t=1}^{T_i} L_{i,t}(\theta) \quad \text{(Response Tokens)} \tag{3} Ti+11t=0TiLi,t(θ)=Ti+11Li,0(θ)(Control Tokens)+Ti+11t=1TiLi,t(θ)(Response Tokens)(3)

对于模式选择,回应风格基于第一个标记 ai,0a_{i,0}ai,0,该标记在之前的蒸馏阶段进行训练。因此,调整这个单一控制标记的概率就足以在不同的推理模式之间切换。因此,这个标记控制推理模式的学习。对于回应准确性,优化目标是提高剩余标记 ai,1:Tia_{i,1:T_i}ai,1:Ti 的生成。然而,上述方程 (3)(3)(3)在优化过程中引入了两种不平衡:

  1. 模式-准确性不平衡 - 每个轨迹只包含一个控制标记,但有 TiT_iTi 个回应标记,不成比例地减少了模式选择的影响;

  2. 长短不平衡 - 更长的序列 Tthinki≫TshortiT_{\text{think}}^i \gg T_{\text{short}}^iTthinkiTshorti 进一步抑制了控制标记的梯度贡献,因为归一化因子 1/(Ti+1)1/(T_i + 1)1/(Ti+1) 导致 <think> 标记相对于 <short> 标记的优化不足。

这种不平衡可能导致训练初期严重的模式崩溃。为了解决这些不平衡,作者提出了一种解耦的 GRPO 变体,记为 JDeGRPOJ_{\text{DeGRPO}}JDeGRPO,分别归一化控制标记和回应标记的贡献:
JDeGRPO(θ)=Ex,ai[1G∑i=1G(αLi,0(θ)(Control Tokens)+1Ti∑t=1TiLi,t(θ)(Response Tokens)−βDKL[πθ(⋅∣x)∥πref(⋅∣x)])]J_{\text{DeGRPO}}(\theta) = \mathbb{E}_{x, a_i} \left[ \frac{1}{G} \sum_{i=1}^G \left( \alpha L_{i,0}(\theta) \quad \text{(Control Tokens)} + \frac{1}{T_i} \sum_{t=1}^{T_i} L_{i,t}(\theta) \quad \text{(Response Tokens)} - \beta D_{\text{KL}} \left[ \pi_\theta(\cdot | x) \parallel \pi_{\text{ref}}(\cdot | x) \right] \right) \right] JDeGRPO(θ)=Ex,ai[G1i=1G(αLi,0(θ)(Control Tokens)+Ti1t=1TiLi,t(θ)(Response Tokens)βDKL[πθ(x)πref(x)])]
在 DeGRPO 中,模式选择Li,0(θ)L_{i,0}(\theta)Li,0(θ) 和回应准确性改进 ∑t=1TiLi,t(θ)\sum_{t=1}^{T_i} L_{i,t}(\theta)t=1TiLi,t(θ) 独立归一化。引入了一个与长度无关的权重系数 α\alphaα 以平衡模式选择和回应生成的优化。这种公式确保控制标记在短序列和长序列中都获得一致的梯度规模,从而解决模式-模式和长短不平衡,实现更稳定的推理模式选择优化。如实验所示,适当大的 α\alphaα 可以使模式更新更高效。在实验中,设置 α=1/1000\alpha = 1/1000α=1/1000 以实现稳定的训练。

方法总结

总之,保留了标准 GRPO 框架的整体结构。对于每个查询,从当前策略中抽取一个 mini-batch 样本以估计标记级别的优势。为了解决模式选择和回应生成之间的不平衡,我们独立归一化与控制标记和回应标记相关的优势。这种分离允许在优化过程中显式平衡它们的贡献,从而实现更稳定和有效的训练。

http://www.dtcms.com/a/278830.html

相关文章:

  • Foundry 私钥管理指南:方法与安全最佳实践
  • 《大数据技术原理与应用》实验报告一 熟悉常用的Linux操作和Hadoop操作
  • PHP password_hash() 函数
  • Fiddler——抓取https接口配置
  • 【解决办法】越疆Dobot CR5 桌面客户端DobotStudio Pro连不上机器人
  • 在Ubuntu系统下使用mpstat工具监控CPU性能
  • 深地之下的智慧触角:Deepoc具身智能如何为矿业机器人铸就“感知之核”
  • CSS3 粘性定位解析:position sticky
  • Go从入门到精通(23) - 一个简单web项目-使用数据库存储数据
  • 解决chrome v2 版本插件不支持
  • 上下文管理器 和 contextlib 模块
  • [硬件电路-22]: 为什么模拟电路信号处理运算的精度不如数字信号处理运算?
  • 《Llava:Visual Instruction Tuning》论文精读笔记
  • 基于Chinese-CLIP与ChromaDB的中文图像检索功能实现
  • 人工智能如何重构能源系统以应对气候变化?
  • 动态规划题解——单词拆分【LeetCode】
  • openEuler系统PCIE降速方法简介
  • 【2025/07/14】GitHub 今日热门项目
  • Self - RAG工作步骤
  • 【HTML】五子棋(精美版)
  • 【Java EE】多线程-初阶 认识线程(Thread)
  • 【C语言进阶】指针面试题详解(2)
  • 面试 | JS 面试题 整理(更ing)2/34
  • Android 16系统源码_窗口动画(二)窗口显示动画源码调用流程
  • 护照阅读器:国外证件识别的 OCR “解码师”
  • Python 中调用阿里云 OCR(Optical Character Recognition,光学字符识别)服务
  • STM32介绍和GPIO
  • stm32-Modbus主机移植程序理解以及实战
  • argus/nvarguscamerasrc 远程显示报错
  • 项目一第一天