【论文阅读】Think Only When You Need with Large Hybrid-Reasoning Models
Think Only When You Need with Large Hybrid-Reasoning Models一文指出,当前的大型推理模型(LRMs)通过生成冗长的思维过程(如标记为 <think> 的中间步骤)显著提升了推理能力,但这种方式在处理简单查询时会带来不必要的计算开销和延迟。为解决这一问题,作者提出了大型混合推理模型(LHRMs),这是第一种能够根据用户查询的上下文信息自适应决定是否进行深入思考的模型。
为实现这一目标,作者设计了一个两阶段的训练流程:
-
混合微调(HFT):作为冷启动阶段,通过结合推理密集型(Thinking)和直接回答(No-Thinking)数据,使模型初步支持两种推理模式。
-
混合组策略优化(HGPO):一种在线强化学习方法,通过隐式学习选择适当的思考模式,同时生成更有用且无害的响应。
此外,作者提出了“混合准确率”(Hybrid Accuracy)这一新指标,用于定量评估模型的混合推理能力。实验结果表明,LHRMs能够根据查询的难度和类型自适应地选择思考模式,在推理和通用任务上均优于现有的LRMs和LLMs,同时显著提升了效率。
本篇博客聚焦文章的方法部分。
2 Large Hybrid-Reasoning Models
2.1 Problem Formulation
本节正式定义了大型混合推理模型(LHRMs)的核心问题,即如何根据输入查询动态选择最优推理模式(Thinking或No-Thinking)以最大化任务特定效用。
关键定义与目标
-
输入与模式:
- 输入查询记为 qqq
- 提供两种推理模式:
- 思考模式(⊢\vdash⊢):生成显式推理步骤(如中间计算或逻辑链)
- 非思考模式(⊀\nprec⊀):直接生成最终答案无需中间步骤
-
条件分布:
- 每种模式对应一个答案空间 A\mathcal{A}A 上的条件概率分布:
P(a∣q,m),m∈M={⊢,⊀}(1)\mathcal{P}(a \mid q, m), \quad m \in \mathcal{M} = \{\vdash, \nprec\} \quad (1) P(a∣q,m),m∈M={⊢,⊀}(1)
- 每种模式对应一个答案空间 A\mathcal{A}A 上的条件概率分布:
-
最优模式选择:
- 对每个查询 qqq,选择能最大化期望效用 U(q,a)\mathcal{U}(q,a)U(q,a) 的模式 m∗(q)m^*(q)m∗(q):
m∗(q)=argmaxm∈MEa∼P(a∣q,m)[U(q,a)](2)m^*(q) = \arg\max_{m\in\mathcal{M}} \mathbb{E}_{a\sim\mathcal{P}(a|q,m)}\Big[\mathcal{U}(q,a)\Big] \quad (2) m∗(q)=argm∈MmaxEa∼P(a∣q,m)[U(q,a)](2)
- 对每个查询 qqq,选择能最大化期望效用 U(q,a)\mathcal{U}(q,a)U(q,a) 的模式 m∗(q)m^*(q)m∗(q):
-
全局优化目标:
- 学习策略 π:Q→M\pi: \mathcal{Q}\rightarrow\mathcal{M}π:Q→M 以最大化跨任务分布的期望效用:
maxπ1N∑i=1NEDi∼Θ,Di⇔Ui[Ea∼P(a∣q,π(q)),q∼Di[Ui(q,a)]](3)\max_{\pi} \frac{1}{N}\sum_{i=1}^N \mathbb{E}_{\mathcal{D}_i\sim\Theta, \mathcal{D}_i\Leftrightarrow\mathcal{U}_i}\Bigg[\mathbb{E}_{a\sim\mathcal{P}(a|q,\pi(q)), q\sim\mathcal{D}_i}\Big[\mathcal{U}_i(q,a)\Big]\Bigg] \quad (3) πmaxN1i=1∑NEDi∼Θ,Di⇔Ui[Ea∼P(a∣q,π(q)),q∼Di[Ui(q,a)]](3)
其中 Θ={(Di,Ui)}i=1N\Theta = \{(\mathcal{D}_i,\mathcal{U}_i)\}_{i=1}^NΘ={(Di,Ui)}i=1N 表示不同任务的数据分布和效用函数对。
- 学习策略 π:Q→M\pi: \mathcal{Q}\rightarrow\mathcal{M}π:Q→M 以最大化跨任务分布的期望效用:
核心挑战与解决方案
-
策略学习(C1):
- 通过两阶段训练实现:
- 阶段I:混合微调(HFT)冷启动
- 阶段II:混合组策略优化(HGPO)强化学习
- 通过两阶段训练实现:
-
评估指标(C2):
- 提出混合准确率 Hacc\mathcal{H}_{\text{acc}}Hacc 量化模式选择能力
2.2 第一阶段:混合微调(Hybrid Fine-Tuning, HFT)
本节详细介绍了LHRMs训练流程的第一阶段——混合微调(HFT),这是模型冷启动的关键步骤。
核心设计
数据构建
HFT使用混合格式的监督微调数据集,包含两类数据:
-
思考模式数据:
- 来源:数学(MATH)、编程(Code)和科学领域的高质量数据集
- 处理方式:
- 使用DeepSeek-R1生成答案
- 人工验证正确性
- 添加
<think>
和</think>
标签标记推理步骤 - 示例:
<think> 首先分析约束条件...然后推导可能的解... </think> 最终答案是$\boxed{17}$
-
非思考模式数据:
- 来源:WildChat-1M中的简单查询
- 处理方式:
- 使用FastText分类器过滤复杂推理任务
- 添加
<no_think>
和</no_think>
标签 - 示例:
<no_think> 当然,请问您需要什么帮助? </no_think>
数据集统计
类别 | 数据量 | 平均token长度 | 主要来源 |
---|---|---|---|
思考模式 | 631,325 | 575 | SYNTHETIC-1, OpenMath |
非思考模式 | 674,908 | 4,897 | WildChat-1M, OASST2 |
总计 | 1,694,586 | - | - |
优化目标(Optimize Objective)
HFT阶段通过标准的语言建模目标训练模型,使其能够基于上文预测下一个token。对于构建的数据集DHFT={(xi,yi)}i=1N\mathcal{D}_{\text{HFT}} = \{(x^i, y^i)\}_{i=1}^NDHFT={(xi,yi)}i=1N,其优化目标定义为:
LHFT(θ)=−E(x,y)∼DHFT[∑t=1∣y∣logπθ(yt∣x,y1:t−1)](4)\mathcal{L}_{\text{HFT}}(\theta) = -\mathbb{E}_{(x,y)\sim\mathcal{D}_{\text{HFT}}} \left[ \sum_{t=1}^{|y|} \log \pi_\theta(y_t \mid x, y_{1:t-1}) \right] \quad (4) LHFT(θ)=−E(x,y)∼DHFTt=1∑∣y∣logπθ(yt∣x,y1:t−1)(4)
其中:
- θ\thetaθ:模型参数
- (x,y)(x,y)(x,y):输入-输出对
- πθ\pi_\thetaπθ:模型参数化的概率分布
关键技术点
-
防模式崩溃设计:
- 对同一查询同时提供两种格式的答案
- 示例:
# 思考模式 "计算2+2": "<think>2加2等于4</think>"# 非思考模式 "计算2+2": "<no_think>4</no_think>"
-
数据平衡策略:
- 思考模式与非思考模式样本比例 ≈ 1:1
- 每个batch内两种模式均匀混合
-
训练配置:
- 优化器:AdamW(lr=1e-4)
- 批次大小:128
- 序列长度:32k tokens
- 训练时长:7B模型约2.5天(4×NVIDIA H100节点)
阶段输出
HFT阶段产出的模型πθHFT\pi_{\theta_{\text{HFT}}}πθHFT具备:
- 同时支持两种推理模式的能力
- 稳定的模式切换基础
- 为第二阶段RL训练提供优质初始化
2.3 第二阶段:混合组策略优化(Hybrid Group Policy Optimization, HGPO)
本节详细介绍训练流程的第二阶段——混合组策略优化(HGPO),这是一种创新的强化学习算法,用于优化模型的自适应推理能力。
HGPO的完整流程如图2和算法1所示,通过以下创新设计降低计算成本:
无Critic模型架构
-
核心设计:
- 摒弃传统强化学习中的critic(价值函数)模型
- 采用多样本估计替代价值函数计算
-
采样机制:
- 对提示集P\mathcal{P}P中的每个问题qqq
- 从旧策略πθHFT\pi_{\theta_{\text{HFT}}}πθHFT中采样两组输出:
- 思考模式组:N/2N/2N/2个含推理过程的响应
- 非思考模式组:N/2N/2N/2个直接答案
计算优化特性
设计选择 | 传统RL | HGPO | 优势 |
---|---|---|---|
价值估计 | Critic模型预测 | 多样本直接统计 | 减少40%训练内存 |
梯度计算 | 依赖价值函数导数 | 零阶策略梯度 | 避免梯度冲突问题 |
模式切换成本 | 需要重训练critic | 动态样本重加权 | 支持在线模式切换 |
算法框架
采样策略(Sampling Strategy)
对于每个查询q∈Pq \in \mathcal{P}q∈P,从初始策略πθHFT\pi_{\theta_{\text{HFT}}}πθHFT中按两种模式分别采样N/2N/2N/2个候选响应:
{oi⊢}i=1N/2∼πθHFT(⋅∣q,m=⊢),{oi⊀}i=1N/2∼πθHFT(⋅∣q,m=⊀)(5)\{o_i^\vdash\}_{i=1}^{N/2} \sim \pi_{\theta_{\text{HFT}}}(\cdot \mid q, m=\vdash), \quad \{o_i^\nprec\}_{i=1}^{N/2} \sim \pi_{\theta_{\text{HFT}}}(\cdot \mid q, m=\nprec) \quad (5) {oi⊢}i=1N/2∼πθHFT(⋅∣q,m=⊢),{oi⊀}i=1N/2∼πθHFT(⋅∣q,m=⊀)(5)
完整候选集定义为:
O(q)={oi⊢}i=1N/2∪{oi⊀}i=1N/2(6)\mathcal{O}(q) = \{o_i^\vdash\}_{i=1}^{N/2} \cup \{o_i^\nprec\}_{i=1}^{N/2} \quad (6) O(q)={oi⊢}i=1N/2∪{oi⊀}i=1N/2(6)
实现细节:
- 默认N=4N=4N=4(每种模式2个样本)
- 温度系数τ=0.7\tau=0.7τ=0.7控制多样性
- 禁止重复采样机制
奖励计算与分配(Reward Scoring and Assignment)
使用奖励函数RϕR_\phiRϕ对候选输出评分,生成两组奖励值:
R⊢={r(oi⊢)}i=1N/2,R⊀={r(oi⊀)}i=1N/2(7)\mathcal{R}^\vdash = \{r(o_i^\vdash)\}_{i=1}^{N/2}, \quad \mathcal{R}^\nprec = \{r(o_i^\nprec)\}_{i=1}^{N/2} \quad (7) R⊢={r(oi⊢)}i=1N/2,R⊀={r(oi⊀)}i=1N/2(7)
计算各模式平均奖励:
Rˉ⊢=2N∑i=1N/2r(oi⊢),Rˉ⊀=2N∑i=1N/2r(oi⊀)(8)\bar{\mathcal{R}}^\vdash = \frac{2}{N}\sum_{i=1}^{N/2} r(o_i^\vdash), \quad \bar{\mathcal{R}}^\nprec = \frac{2}{N}\sum_{i=1}^{N/2} r(o_i^\nprec) \quad (8) Rˉ⊢=N2i=1∑N/2r(oi⊢),Rˉ⊀=N2i=1∑N/2r(oi⊀)(8)
定义两种奖励类型:
- 组间奖励(Inter-group):
rinter(oim)={1,if m=argmaxm′∈{⊢,⊀}{Rˉ⊢,Rˉ⊀+δ}0,otherwise(9a)r_{\text{inter}}(o_i^m) = \begin{cases} 1, & \text{if } m = \arg\max_{m'\in\{\vdash,\nprec\}} \{\bar{\mathcal{R}}^\vdash, \bar{\mathcal{R}}^\nprec + \delta\} \\ 0, & \text{otherwise} \end{cases} \quad (9a) rinter(oim)={1,0,if m=argmaxm′∈{⊢,⊀}{Rˉ⊢,Rˉ⊀+δ}otherwise(9a) - 组内奖励(Intra-group):
rintra(oim)={1,if i=argmaxj∈{1,...,N/2}rjm0,otherwise(9b)r_{\text{intra}}(o_i^m) = \begin{cases} 1, & \text{if } i = \arg\max_{j\in\{1,...,N/2\}} r_j^m \\ 0, & \text{otherwise} \end{cases} \quad (9b) rintra(oim)={1,0,if i=argmaxj∈{1,...,N/2}rjmotherwise(9b)
关键参数:
- δ\deltaδ:模式偏好边际(默认0.2)
- 规则型奖励用于数学/编程等确定性任务
- 参数化奖励模型用于开放域任务
δ\deltaδ这个参数的出现提供了一种可以控制模型思考偏好的方法,在具体工程实现中,可以基于任务种类设置不同的δ\deltaδ达到控制长短的目的
优势估计(Advantage Estimation)
采用GRPO优势估计器:
Ait=[rintra(oi)−mean(rintra(oj))std(rintra(oj))]⏟Intra-group+1{oit∈Φ}⋅α[rinter(oi)−mean(rinter(oj))std(rinter(oj))]⏟Inter-group(10)A_i^t = \underbrace{\left[\frac{r_{\text{intra}}(o_i) - \text{mean}(r_{\text{intra}}(o_j))}{\text{std}(r_{\text{intra}}(o_j))}\right]}_{\text{Intra-group}} + \underbrace{\mathbb{1}\{o_i^t \in \Phi\} \cdot \alpha \left[\frac{r_{\text{inter}}(o_i) - \text{mean}(r_{\text{inter}}(o_j))}{\text{std}(r_{\text{inter}}(o_j))}\right]}_{\text{Inter-group}} \quad (10) Ait=Intra-group[std(rintra(oj))rintra(oi)−mean(rintra(oj))]+Inter-group1{oit∈Φ}⋅α[std(rinter(oj))rinter(oi)−mean(rinter(oj))](10)
其中:
- Φ={<think>,<no_think>}\Phi = \{\text{<think>}, \text{<no\_think>}\}Φ={<think>,<no_think>}为模式标记集合
- α=1.0\alpha=1.0α=1.0为平衡系数
优化目标(Optimization Objective)
最大化以下目标函数:
JHGPO(θ)=Eq∼P,{oim}∼πθHFT[1N∑i=1N∑t=1∣o∣[min(πθ(oim,t∣q,oim,<t)πθHFT(oim,t∣q,oim,<t)Ait,clip(πθ(oim,t∣q,oim,<t)πθHFT(oim,t∣q,oim,<t),1−ϵ,1+ϵ)Ait)−βDKL(πθ∣∣πref)]](11)\mathcal{J}_{\text{HGPO}}(\theta) = \mathbb{E}_{q\sim\mathcal{P}, \{o_i^m\}\sim\pi_{\theta_{\text{HFT}}}}\Bigg[ \frac{1}{N}\sum_{i=1}^N \sum_{t=1}^{|o|} \bigg[ \min\Bigg( \frac{\pi_\theta(o_i^{m,t}|q,o_i^{m,<t})}{\pi_{\theta_{\text{HFT}}}(o_i^{m,t}|q,o_i^{m,<t})} A_i^t, \\ \text{clip}\Bigg(\frac{\pi_\theta(o_i^{m,t}|q,o_i^{m,<t})}{\pi_{\theta_{\text{HFT}}}(o_i^{m,t}|q,o_i^{m,<t})}, 1-\epsilon, 1+\epsilon\Bigg) A_i^t \bigg) - \beta \mathbb{D}_{\text{KL}}(\pi_\theta || \pi_{\text{ref}}) \bigg] \Bigg] \quad (11) JHGPO(θ)=Eq∼P,{oim}∼πθHFT[N1i=1∑Nt=1∑∣o∣[min(πθHFT(oim,t∣q,oim,<t)πθ(oim,t∣q,oim,<t)Ait,clip(πθHFT(oim,t∣q,oim,<t)πθ(oim,t∣q,oim,<t),1−ϵ,1+ϵ)Ait)−βDKL(πθ∣∣πref)]](11)
KL散度项展开为:
DKL(πθ∣∣πref)=πref(oim∣q)πθ(oim∣q)−logπref(oim∣q)πθ(oim∣q)−1(12)\mathbb{D}_{\text{KL}}(\pi_\theta || \pi_{\text{ref}}) = \frac{\pi_{\text{ref}}(o_i^m|q)}{\pi_\theta(o_i^m|q)} - \log \frac{\pi_{\text{ref}}(o_i^m|q)}{\pi_\theta(o_i^m|q)} - 1 \quad (12) DKL(πθ∣∣πref)=πθ(oim∣q)πref(oim∣q)−logπθ(oim∣q)πref(oim∣q)−1(12)
训练配置:
- 学习率:1×10−61\times10^{-6}1×10−6(恒定)
- 批次大小:256(微批次8)
- KL系数β=0.001\beta=0.001β=0.001
- 裁剪阈值ϵ=0.5\epsilon=0.5ϵ=0.5
- 训练时长:2天(4×H100)
算法特性
-
双重奖励机制:
- 组间奖励引导模式选择
- 组内奖励优化内容质量
-
策略约束:
- KL惩罚项防止过度偏离初始策略
- 重要性采样裁剪保证稳定性
-
零阶优化:
无需价值函数模型,直接基于样本奖励优化
2.4 混合推理能力评估
为更全面地评估LHRMs的性能(超越传统下游任务指标),文章提出新指标混合准确率(Hybrid Accuracy, Hacc\mathcal{H}_{acc}Hacc),用于量化模型选择正确推理模式的能力。
评估流程
给定任务提示集P={pi}i=1K\mathcal{P} = \{p_i\}_{i=1}^KP={pi}i=1K:
- 对每个pip_ipi,模型在⊢\vdash⊢和⊀\nprec⊀模式下各生成NNN个响应
- 使用奖励模型RϕR_\phiRϕ对响应评分,计算各模式平均得分Rˉ⊢\bar{\mathcal{R}}^\vdashRˉ⊢和Rˉ⊀\bar{\mathcal{R}}^\nprecRˉ⊀
- 确定基准模式mgtm_{gt}mgt:
- 若∣Rˉ⊢−Rˉ⊀∣>ϵ|\bar{\mathcal{R}}^\vdash - \bar{\mathcal{R}}^\nprec| > \epsilon∣Rˉ⊢−Rˉ⊀∣>ϵ,选择高分模式
- 否则选择响应更短的模式
- 模型自主选择模式mpm_pmp,计算匹配比例:
Hacc=1K∑i=1K1[Equal(mgt,mp)]s.t.mgt,mp∈{⊢,⊀}(13)\mathcal{H}_{acc} = \frac{1}{K}\sum_{i=1}^K \mathbb{1}\left[\text{Equal}(m_{gt}, m_p)\right] \quad \text{s.t.} \quad m_{gt}, m_p \in \{\vdash, \nprec\} \quad (13) Hacc=K1i=1∑K1[Equal(mgt,mp)]s.t.mgt,mp∈{⊢,⊀}(13)
关键参数:
- ϵ\epsilonϵ:模式得分差异阈值(默认0.05)
- NNN:每种模式采样数(默认4)