DAPO:对GRPO的几点改进
DAPO:对GRPO的几点改进
TL; DR:对 GRPO 的几处细节进行了优化,包括去除 KL 约束、解耦 ppo-clip 的上下界,上界设置更高以鼓励探索、超长回答过滤、token level 损失计算等。相比于原始 GRPO,在 AIME24 上提升非常显著。
方法
依次介绍下 DAPO 中对 GRPO 提出的几点改进。
0. remove kl and rule-based reward
在 RLHF 中,会使用 KL 散度约束当前 policy π θ \pi_\theta πθ 与初始的参考 policy π ref \pi_\text{ref} πref 之间的差异不要太大,避免模型训飞。作者认为,这个约束不是很必要,尤其是在训练 Long CoT 能力时,模型的输出分布就是会跟初始模型有比较大的差异,加这个约束会过于限制 RLHF 训练的提升空间,所以直接去掉了。
在 R1 之后,优先用可验证的 rule-based reward 已经基本是个共识了。使用基于规则的奖励,可以避免 reward model 带来的 reward hacking 的问题。在数学和编程两个推理能力的重点领域,都有可验证的 rule-based reward(数学题答案和代码编译运行结果)。DAPO 主要做的是数学任务,也使用了 rule-based model,奖励计算方式如下,直接对比模型预测 y ^ \hat{y} y^ 与真实结果 y y y 是否相等,相等给 1,不等都给 -1:
R ( y ^ , y ) = { 1 , is_equivalent ( y , y ^ ) − 1 , otherwise R(\hat{y},y)= \begin{aligned} \begin{cases} 1,\quad &\text{is\_equivalent}(y,\hat{y}) \\ -1,\quad &\text{otherwise} \end{cases} \end{aligned} \notag \\ R(y^,y)={1,−1,is_equivalent(y,y^)otherwise
1. clip higher
DAPO 的第一个重要改进是 clip higher。
作者发现,在实际使用 PPO / GRPO 时会出现熵坍塌(entropy collapse)的现象:训练过程中 policy 输出分布的熵快速下降,即对于一个问题,采样出的一个 group 内的回答趋近于相同。这导致 policy 过于确定,exploitation 有余而 exploration 不足,大大限制了强化学习的潜力。
Exploitation v.s. Exploration 是强化学习中的一个重要课题,Exploitation 指利用已有的经验获取尽可能高的累积回报,Exploration 则是指训练时不只关注短期的回报,而是也以一定概率采样其他动作,尽量地探索是否有其他能赢得更高回报的可能选择。
由于强化学习的数据是自己 rollout 得到的,而非人为给定的数据,所以需要自己考虑去 Explore 更多的可能性。
作者认为,这种熵坍塌现象的出现,是由于 PPO-CLIP 中的 upper clip,按照比率 r = π θ π θ old r=\frac{\pi_\theta}{\pi_{\theta_\text{old}}} r=πθoldπθ 限制更新的上限,相当于是打压了低概率 token(exploration token),而鼓励了高概率 token(exploitation token)。
举个例子,PPO-clip 的超参数 ϵ \epsilon ϵ 的典型值为 0.2,我们考虑优势 A ^ t > 0 \hat{A}_t>0 A^t>0,也就是当前动作需要被鼓励(增大该动作被 π θ \pi_\theta πθ 采样出来的概率)的情况。假设有两个动作被 π θ old \pi_{\theta_\text{old}} πθold 采样出来的概率分别是 0.01 和 0.9,那么,由于 clip 操作的限制,我们最大只能把 π θ \pi_\theta πθ 采样出这两个动作的概率分别提升(+20%)到 0.012 和 1.08 。可以看到,由于 clip 限制的是提升的比率,所以原来概率比较低的动作(这些动作可以看作是在做 exploration,所以称为 exploration token)提升的幅度要远远小于概率比较高的动作(这些动作是要最大化累积回报,所以称为 exploitation token),所以说,exploration 被打压了,而 exploitation 更被鼓励。
作者的解决方案是 clip higher,就是把 PPO 中的 lower/upper 两个 clip 超参的设置解耦开,对 upper clip 设置一个比较高的值 ϵ high \epsilon_\text{high} ϵhigh,从而让好的动作有更大的提升空间。而 lower clip 的超参 ϵ low \epsilon_\text{low} ϵlow 还是要保持相对较小的值,因为如果这个值大了,会对 A ^ t < 0 \hat{A}_t<0 A^t<0 动作的打压力度更强,导致这些动作更加不可能被采样到,这样也会使得 explore 空间受限。
有无 clip higher 在 AIME 上性能和采样结果熵的对比如下图所示。可以看到,如果没有采用 clip higher,熵就快速降至接近 0 了,此时(1000steps)左右,性能也基本收敛了。而加上 clip higher,熵能够保持在 0.4 左右,模型仍有一定的 explore 能力,性能也可以持续提升。
有几个问题:
- 个人感觉这个 clip higher 不是很对症啊。照文中分析的这个问题,问题的根源在于重要性采样的这个比率 r r r 是按照比例来约束更新上界的,但是 exploration tokens 本身肯定就比 exploitation tokens 概率要低,因此相对提升空间会更小。那么解决方案应该是要重点扶持一下 low prob positive tokens (exploration tokens)。而像 clip higher 这样直接统一拉高 upper epsilon,high prob tokens (exploitation tokens) 不是反而受益会更大吗?
- 原文图 3 (a) 想说明啥?
2. dynamic sampling
DAPO 中,延用了 GRPO 的做法:不使用 critic model 来计算 GAE,而是通过基于组内全部采样结果 { o i } i = 1 G \{o_i\}_{i=1}^G {oi}i=1G 的均值和标准差来进行归一化,作为 baseline,计算优势函数 A ^ t \hat{A}_t A^t。这样在极端情况下,即整组的结果全对或全错时, A ^ t \hat{A}_t A^t 就为零了,也就没有梯度了。这样显然会影响训练效率,尤其是当训练后期,模型能力变强了,组内采样结果全对的情况会更多,对训练效率的影响更大。而且这样梯度的方差也会很大,对模型训练也是不利的。
解决这个问题的方案很简单,就是在采样时进行过滤,将组内全对或全错的样本直接丢掉,保证组内优势 A ^ t \hat{A}_t A^t 都是非零的。如果 RL 训练框架不是异步的,而且 rollout 过程不是流水线的话,rollout 占用的时间是由长尾的样本决定的,因此这种情况下加这个过滤策略对训练效率的影响也不会太大。
3. token-level policy gradient loss
GRPO 中采用的是 sample level 的损失,即对每个样本的所有 token 先计算平均损失,再算多样本的平均损失。这样相当于是长回答和短回答的权重是相同的,作者认为长回答的权重应该更高,否则对于长回答中的高低质量响应模式无法得到足够强度的鼓励和惩罚。
解决方案就是不再对同一条样本内 token 损失单独进行平均,相当于是对所有 token 做 token level 的 loss。这样相当于对长回答中的高低质量的响应模式的鼓励和惩罚权重就更大了。在 Long CoT 的场景中,比较有用。
4. overlong reward shaping
DAPO 的第四点改进,是对于超长回答的惩罚设计,包括 overlong filtering 和 soft overlong punishment 两点。
一般情况下,我们会直接对超出最大长度的样本进行惩罚。但是作者认为这会带来一些噪声,因为一些超长的回答,本身推理过程的思路可能是 ok 的,只是因为超出最大长度,就得到一个负的 reward,对这样的高质量但是超长的回答进行惩罚,可能会让模型比较困惑。
初步解决方案是对于超长的响应,不算对也不算错,直接 mask 掉超长样本的 loss,不要这些样本了,即 overlong filtering。实验发现,这样可以稳定训练并提升性能。
进一步地,作者还提出了一种超长软惩罚(soft overlong punishment)。具体来说就是对于超出预定最大长度的样本,定义一个惩罚区间,在区间内,回答越长,惩罚越大:
R length ( y ) = { 0 , if ∣ y ∣ ≤ L max − L cache ( L max − L cache ) − ∣ y ∣ L cache , if L max − L cache < ∣ y ∣ ≤ L max − 1 , if L max < ∣ y ∣ R_{\text{length}}(y) = \begin{cases} 0, & \text{if } |y| \leq L_{\text{max}} - L_{\text{cache}} \\ \frac{(L_{\text{max}} - L_{\text{cache}}) - |y|}{L_{\text{cache}}}, & \text{if } L_{\text{max}} - L_{\text{cache}} < |y| \leq L_{\text{max}} \\ -1, & \text{if } L_{\text{max}} < |y| \end{cases} \notag \\ Rlength(y)=⎩ ⎨ ⎧0,Lcache(Lmax−Lcache)−∣y∣,−1,if ∣y∣≤Lmax−Lcacheif Lmax−Lcache<∣y∣≤Lmaxif Lmax<∣y∣
put it togther
最终,DAPO 整体的优化目标如下式,式中不同颜色表示了 DAPO 提出的几点改进。
J DAPO ( θ ) = E q , a [ 1 ∑ i = 1 G ∣ o i ∣ ∑ i = 1 G ∑ t = 1 ∣ o i ∣ min ( r i , t ( θ ) A ^ i , t , clip ( r i , t ( θ ) , 1 − ϵ low , 1 + ϵ high ) A ^ i , t ) ] s.t. 0 < ∣ { o i ∣ is_equivalent ( a , o i ) } ∣ < G 其中 r i , t ( θ ) = π θ ( o i , t ∣ q , o i , < t ) π θ old ( o i , t ∣ q , o i , < t ) , A ^ i , t = R i − mean ( { R i } i = 1 G ) std ( { R i } i = 1 G ) \mathcal{J}_\text{DAPO}(\theta)=\mathbb{E}_{q,a}\left[\frac{1}{\textcolor{green}{\sum_{i=1}^G|o_i|}}\textcolor{green}{\sum_{i=1}^G\sum_{t=1}^{|o_i|}}\min\left(r_{i,t}(\theta)\hat{A}_{i,t},\text{clip}\left(r_{i,t}(\theta),1-\textcolor{red}{\epsilon_\text{low}},\ 1+\textcolor{red}{\epsilon_\text{high}}\right)\hat{A}_{i,t}\right)\right] \\ \text{s.t.}\textcolor{blue}{\quad 0<|\{o_i|\text{is\_equivalent}(a,o_i)\}|<G} \\ 其中\quad r_{i,t}(\theta)=\frac{\pi_\theta(o_{i,t}|q,o_{i,<t})}{\pi_{\theta_\text{old}}(o_{i,t}|q,o_{i,<t})},\quad \hat{A}_{i,t}=\frac{R_i-\text{mean}(\{R_i\}_{i=1}^G)}{\text{std}(\{R_i\}_{i=1}^G)} \notag \\ JDAPO(θ)=Eq,a ∑i=1G∣oi∣1i=1∑Gt=1∑∣oi∣min(ri,t(θ)A^i,t,clip(ri,t(θ),1−ϵlow, 1+ϵhigh)A^i,t) s.t.0<∣{oi∣is_equivalent(a,oi)}∣<G其中ri,t(θ)=πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t),A^i,t=std({Ri}i=1G)Ri−mean({Ri}i=1G)
整体的算法流程如下所示,在各项改进都介绍过之后,应该比较清楚了:
实验
这类偏实践性质的技术报告,最有价值的部分就是实验,尤其是消融实验。DAPO 在实验部分也是上来就贴出了消融的结果,展示上述各种设计对性能提升的贡献。可以看到,使用最原始的 GRPO,在 AIME24 (avg@32) 上只能达到 30 的精度,远不及 R1 报告的 47,在应用 DAPO 提出的几点改进之后,一路涨到了 50,从这个结果看,提升还是非常显著的。尤其是 overlong filtering 和 dynamic sampling 两项,各自单独就有 6-8 个点的提升。
总结
DAPO 对标准 GRPO 的一些细节进行了改进优化,从实验结果来看,提升非常显著。字节能开放自己在 RLHF 领域的探索结果出来肯定是非常好的。美中不足是感觉报告的写作比较潦草,并且实验略显单薄,以及对于 clip higher 这个解决方案个人感觉不是很对症。当然也可能是我自己理解不到位 😃,期望有大佬能指点下我的困惑。