【AI Infra】【RLHF框架】四、VeRL中PPO、GRPO、REINFORCE++、RLOO实现源码解析
系列文章:
【AI Infra】【RLHF框架】一、VeRL中基于Ray的执行流程源码解析
【AI Infra】【RLHF框架】二、VeRL中colocate实现解析
【AI Infra】【RLHF框架】三、VeRL中的Rollout实现源码解析
【AI Infra】【RLHF框架】四、VeRL中PPO、GRPO、REINFORCE++、RLOO实现源码解析
相比于前三篇博客,本偏博客涉及的公式较多,但对于理解RL框架这也是不可避免的。
本篇博客尽量采用统一的符号来介绍PPO、GRPO、REINFORCE++和RLOO的理论,这样便于理解各类算法之间的关联。此外,很多时候实现代码为了简洁或者数值稳定等原因,不一定完全按照公式进行实现。本文也尽量补齐理论和实现的差距。
希望这篇博客对你有所帮助,如有错误,欢迎指正。
一、LLM场景下的PPO
1. PPO原理
自回归语言模型
π
θ
\pi_{\theta}
πθ将提示
x
x
x作为输入,然后通过自回归的方式逐步产生输出
y
y
y。若采用贪心解码的方式,则生成
y
y
y中第
t
t
t个token的过程为
y
t
=
arg
max
y
π
θ
(
y
∣
x
,
y
<
t
)
(1)
y_t = \arg\max_y\pi_{\theta}(y|x,y_{<t}) \tag{1}\\
yt=argymaxπθ(y∣x,y<t)(1)
若将这个过程建模为马尔科夫决策过程(MDP),则
s
t
=
x
⊕
y
<
t
s_t=x\oplus y_{<t}
st=x⊕y<t,
a
t
a_t
at表示进行第
t
t
t个token的选择。单个样本的生成过程构成了一条完整的轨迹
τ
=
(
s
0
,
a
t
,
s
1
,
a
1
,
.
.
.
)
\tau=(s_0,a_t,s_1,a_1,...)
τ=(s0,at,s1,a1,...)。
基于上面的定义,LLM场景下PPO损失函数定义为
L
PPO
(
θ
)
=
−
E
x
∼
p
,
y
∼
π
(
⋅
∣
x
)
1
∣
y
∣
∑
t
=
1
∣
y
∣
min
[
r
t
(
θ
)
A
t
,
clip
(
r
t
(
θ
)
,
1
−
ε
,
1
+
ε
)
A
t
]
(2)
L_{\text{PPO}}(\theta)=-\text{E}_{x\sim p,y\sim\pi(\cdot|x)}\frac{1}{|y|}\sum_{t=1}^{|y|}\min\Big[r_t(\theta)A_t,\text{clip}(r_t(\theta),1-\varepsilon,1+\varepsilon)A_t\Big] \tag{2}\\
LPPO(θ)=−Ex∼p,y∼π(⋅∣x)∣y∣1t=1∑∣y∣min[rt(θ)At,clip(rt(θ),1−ε,1+ε)At](2)
- A t A_t At是优势函数,表示采取动作 a t a_t at的预期收益,该值的计算会在后续GAE小节中详细介绍。
- r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st)是重要性采样,通过重要性采样能够调整旧策略 π θ old \pi_{\theta_{\text{old}}} πθold样本的权重来适应新策略 π θ \pi_{\theta} πθ的分布。例如,旧策略以高概率选择了某个动作 a a a,但新策略以低概率选择该动作,那么 r t ( θ ) r_t(\theta) rt(θ)会较小,从而降低该样本对整个梯度的影响。反之,权重会放大其贡献。
- clip \text{clip} clip是截断函数,用于防止出现大幅度更新,导致模型退化。
在verl中该损失函数的计算如下:
# verl/trainer/ppo/core_algos.py
def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
"""
old_log_prob: (bs, response_length)
log_prob: (bs, response_length)
advantages: (bs, response_length)
eos_mask: (bs, response_length)
"""
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl
2. 广义优势估计(GAE)
在PPO原理介绍中提及了优势函数 A t A_t At,本小节介绍估计 A t A_t At的常用方式GAE。
2.1 优化函数的定义
A t = Q ( s t , a t ) − V ( s t ) (3) A_t=Q(s_t,a_t)-V(s_t) \tag{3}\\ At=Q(st,at)−V(st)(3)
其含义是在当前状态 s t s_t st下采取动作 a t a_t at的未来预期收益 Q ( s t , a t ) Q(s_t,a_t) Q(st,at)与状态 s t s_t st下平均动作预期收益的差值,即衡量动作 a t a_t at比平均动作好多少。在LLM的RL场景中, A t A_t At表示在第 t t t个token选择能获得的收益比平均选择所有token高多少。
2.2 TD估计
单步TD估计。由于
Q
Q
Q和
V
V
V这两函数都是未知的,理论上需要两个深度模型来分别拟合。但是,通过贝尔曼方程将
Q
Q
Q转换为
V
V
V,即
A
t
1
=
r
t
+
γ
V
(
s
t
+
1
)
−
V
(
s
t
)
A_t^{1}=r_t+\gamma V(s_{t+1}) -V(s_t) \\
At1=rt+γV(st+1)−V(st)
其中
γ
\gamma
γ是折扣因子,
r
t
r_t
rt是第
t
t
t步的实际奖励。这样就能使用单个神经网络来估计优势值,其中拟合
V
V
V的这个网络就是Critic
模型。这种包含了一步真实奖励
r
t
r_t
rt的估计称为优势函数的单步TD估计。但是,单步TD存在低方差但偏差高的问题。因为,仅依赖一步真实奖励的方差较低,而更多使用Critic
模型的预测值则可能导致高偏差(模型预测值往往波动较小,但可能存在整体预测偏高或偏低的问题)。
n步TD估计。为了方便表示,令
δ
t
=
r
t
+
γ
V
(
s
t
+
1
)
−
V
(
s
t
)
\delta_t=r_t+\gamma V(s_{t+1})-V(s_t)
δt=rt+γV(st+1)−V(st)。那么单步TD估计表示为
A
t
1
=
δ
t
A_t^1=\delta_t \\
At1=δt
为了降低单步TD估计中的偏差,可以考虑估计中使用更多的真实奖励,那么通过贝尔曼方程将单步TD中的
V
(
s
t
+
1
)
V(s_{t+1})
V(st+1)展开,就得到2步TD估计
A
t
2
=
r
t
+
γ
(
r
t
+
1
+
γ
V
(
s
t
+
2
)
)
−
V
(
s
t
)
=
r
t
+
γ
V
(
s
t
+
1
)
−
V
(
s
t
)
+
γ
(
r
t
+
1
+
γ
V
(
s
t
+
2
)
−
V
(
s
t
+
1
)
)
=
δ
t
+
γ
δ
t
+
1
\begin{align} A_t^2&=r_t+\gamma(r_{t+1}+\gamma V(s_{t+2})) - V(s_t) \\ &=r_t+\gamma V(s_{t+1})-V(s_t)+\gamma(r_{t+1}+\gamma V(s_{t+2})-V(s_{t+1})) \\ &=\delta_t+\gamma\delta_{t+1} \\ \end{align} \\
At2=rt+γ(rt+1+γV(st+2))−V(st)=rt+γV(st+1)−V(st)+γ(rt+1+γV(st+2)−V(st+1))=δt+γδt+1
以此类推,n步的TD估计为
A
t
n
=
∑
l
=
0
n
−
1
γ
l
δ
t
+
l
(4)
A_t^n=\sum_{l=0}^{n-1}\gamma^l\delta_{t+l} \tag{4}\\
Atn=l=0∑n−1γlδt+l(4)
2.3 广义优势估计(GAE)原理
n步TD估计中n越大则使用的真实奖励越多,对应的方差就越高,偏差越低。由于不同任务对应的最优n是不同的,很难精确选择合适的n,GAE中则选择融合所有可能的步长(
n
=
1
,
2
,
…
,
∞
n=1,2,\dots,\infty
n=1,2,…,∞)。此外,引入一个衰减因子
λ
∈
[
0
,
1
]
\lambda\in[0,1]
λ∈[0,1]来对远期步长施加衰减,即
A
t
GAE
=
∑
l
=
0
∞
(
γ
λ
)
l
δ
t
+
1
(5)
A_t^{\text{GAE}}=\sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+1} \tag{5}\\
AtGAE=l=0∑∞(γλ)lδt+1(5)
当
λ
=
0
\lambda=0
λ=0时,GAE估计退化为
δ
t
\delta_t
δt;当
λ
=
1
\lambda=1
λ=1时,GAE等价于蒙特卡洛估计(全部使用真实奖励)。所以,
λ
\lambda
λ可以当做是调节方差和偏差的超参数,
λ
\lambda
λ越大则真实奖励越多,方差越大,偏差越小。
2.4 GAE的计算
GAE的计算。在实际计算中,通常需要计算所有步骤的GAE,为了避免重复计算可以采用迭代的方式从轨迹末端开始计算。
A
t
GAE
=
δ
t
+
(
γ
λ
)
δ
t
+
1
+
(
γ
λ
)
2
δ
t
+
2
+
.
.
.
=
δ
t
+
γ
λ
(
δ
t
+
1
+
(
γ
λ
)
δ
t
+
2
+
…
)
=
δ
t
+
γ
λ
A
t
+
1
GAE
\begin{align} A_t^{\text{GAE}}&=\delta_t+(\gamma\lambda)\delta_{t+1}+(\gamma\lambda)^2\delta_{t+2}+... \\ &=\delta_t+\gamma\lambda(\delta_{t+1}+(\gamma\lambda)\delta_{t+2}+\dots) \\ &=\delta_t+\gamma\lambda A_{t+1}^{\text{GAE}} \end{align} \\
AtGAE=δt+(γλ)δt+1+(γλ)2δt+2+...=δt+γλ(δt+1+(γλ)δt+2+…)=δt+γλAt+1GAE
通过上面的公式可以方便的迭代计算GAE。下面是verl中的实现代码:
# verl/trainer/ppo/core_algos.py
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor, lam: torch.Tensor):
"""
token_level_rewards: (bs, response_length)
values: (bs, response_length)
"""
with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
return advantages, returns
代码中优势函数的计算与上面迭代公式完全一致。此外,这里还计算了回报returns
。回忆一下优势的定义为
A
t
=
Q
(
s
t
,
a
t
)
−
V
(
s
t
)
A_t=Q(s_t,a_t)-V(s_t)
At=Q(st,at)−V(st),那么
Q
(
s
t
,
a
t
)
=
A
t
+
V
(
s
t
)
Q(s_t,a_t)=A_t+V(s_t)
Q(st,at)=At+V(st)。
Q
(
s
t
,
a
t
)
Q(s_t,a_t)
Q(st,at)的定义就是状态
s
t
s_t
st下动作
a
t
a_t
at的折扣回报。
3. KL约束以及近似计算
3.1 KL约束
损失函数
L
PPO
(
θ
)
L_{\text{PPO}}(\theta)
LPPO(θ)中主要是最大化预期收益,但是在LLM场景中仅最大化预期收益可能导致生成的response不符合自然语言的表达。为了解决这个问题引入了KL约束,使得优化后的模型不要太偏离ref(sft)模型。那么损失函数变为
L
PPO
(
θ
)
=
−
E
x
∼
p
,
y
∼
π
(
⋅
∣
x
)
1
∣
y
∣
{
∑
t
=
1
∣
y
∣
min
[
r
t
(
θ
)
A
t
,
clip
(
r
t
(
θ
)
,
1
−
ε
,
1
+
ε
)
A
t
]
−
β
KL
[
π
θ
old
(
a
t
∣
s
t
)
,
π
θ
ref
(
a
t
∣
s
t
)
]
}
(6)
L_{\text{PPO}}(\theta)=-\text{E}_{x\sim p,y\sim\pi(\cdot|x)}\frac{1}{|y|}\Big\{\sum_{t=1}^{|y|}\min[r_t(\theta)A_t,\text{clip}(r_t(\theta),1-\varepsilon,1+\varepsilon)A_t]\\ -\beta\text{KL}[\pi_{\theta_{\text{old}}}(a_t|s_t),\pi_{\theta_{\text{ref}}}(a_t|s_t)]\Big\} \tag{6}\\
LPPO(θ)=−Ex∼p,y∼π(⋅∣x)∣y∣1{t=1∑∣y∣min[rt(θ)At,clip(rt(θ),1−ε,1+ε)At]−βKL[πθold(at∣st),πθref(at∣st)]}(6)
其中
KL
[
q
,
p
]
=
∑
x
q
(
x
)
log
q
(
x
)
p
(
x
)
=
E
x
∼
q
[
log
q
(
x
)
p
(
x
)
]
(7)
\text{KL}[q,p]=\sum_x q(x)\log\frac{q(x)}{p(x)}=E_{x\sim q}\Big[\log\frac{q(x)}{p(x)}\Big] \tag{7}\\
KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)](7)
3.2 近似计算
在计算KL散度时需要计算两个模型在整个词表上的概率,为了简化计算过程,在实现时通常采用近似计算。
k1估计器。一种朴素的近似方式是使用无偏估计
k
1
=
log
q
(
x
)
p
(
x
)
(8)
k1 = \log\frac{q(x)}{p(x)} \tag{8}\\
k1=logp(x)q(x)(8)
但这个估计器是高方差的。因为只要
q
(
x
)
<
p
(
x
)
q(x)<p(x)
q(x)<p(x),
k
1
k1
k1就是负值,但KL散度应该是非负的。
k2估计器。为了解决
k
1
k1
k1的高方差问题,可以使用
k
2
k2
k2估计器
k
2
=
1
2
(
log
p
(
x
)
q
(
x
)
)
2
(9)
k2=\frac{1}{2}(\log\frac{p(x)}{q(x)})^2 \tag{9}\\
k2=21(logq(x)p(x))2(9)
显然,
k
2
k2
k2估计器是有偏的,但在实际使用中偏差较小且能有效降低方差。
k3估计器。因为k1是无偏的,那么降低方差可以通过添加一个期望为0但与k1负相关的项。
p
(
x
)
q
(
x
)
−
1
\frac{p(x)}{q(x)}-1
q(x)p(x)−1的期望为0,因为
E
x
∼
q
[
p
(
x
)
q
(
x
)
−
1
]
=
E
x
∼
q
[
p
(
x
)
q
(
x
)
]
−
1
=
0
\text{E}_{x\sim q}\Big[\frac{p(x)}{q(x)}-1\Big]=\text{E}_{x\sim q}\Big[\frac{p(x)}{q(x)}\Big] -1 = 0
Ex∼q[q(x)p(x)−1]=Ex∼q[q(x)p(x)]−1=0。因此,k3估计器定义为
k
3
=
p
(
x
)
q
(
x
)
−
1
−
log
p
(
x
)
q
(
x
)
(10)
k3=\frac{p(x)}{q(x)}-1-\log\frac{p(x)}{q(x)} \tag{10}\\
k3=q(x)p(x)−1−logq(x)p(x)(10)
# verl/trainer/ppo/core_algos.py
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
if kl_penalty == "kl":
return logprob - ref_logprob
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
if kl_penalty == 'low_var_kl':
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
kl
、mse
、low_var_kl
分别对应
k
1
k1
k1、
k
2
k2
k2和
k
3
k3
k3,abs
则是通过绝对值的方式解决
k
1
k1
k1为负的问题。
4. 熵正则
在PPO训练的过程,有可能过早将概率集中在某个特定的token上。这意味着过早放弃了探索,容易陷入局部最优。熵正则就是通过在损失函数中添加概率分布
π
θ
(
a
t
∣
s
t
)
\pi_{\theta}(a_t|s_t)
πθ(at∣st)的信息熵,从而鼓励学习过程中保持一定的随机性。
L
PPO
(
θ
)
=
−
E
x
∼
p
,
y
∼
π
(
⋅
∣
x
)
1
∣
y
∣
∑
t
=
1
∣
y
∣
{
min
[
r
t
(
θ
)
A
t
,
clip
(
r
t
(
θ
)
,
1
−
ε
,
1
+
ε
)
A
t
]
−
β
1
KL
[
π
θ
old
(
a
t
∣
s
t
)
,
π
θ
ref
(
a
t
∣
s
t
)
]
−
β
2
∑
a
t
π
θ
(
a
t
∣
s
t
)
ln
π
θ
(
a
t
∣
s
t
)
}
(11)
L_{\text{PPO}}(\theta)=-\text{E}_{x\sim p,y\sim\pi(\cdot|x)}\frac{1}{|y|}\sum_{t=1}^{|y|}\Big\{\min[r_t(\theta)A_t,\text{clip}(r_t(\theta),1-\varepsilon,1+\varepsilon)A_t]\\ - \beta_1\text{KL}[\pi_{\theta_{\text{old}}}(a_t|s_t),\pi_{\theta_{\text{ref}}}(a_t|s_t)] - \beta_2\sum_{a_t}\pi_{\theta}(a_t|s_t)\ln\pi_{\theta}(a_t|s_t)\Big\} \tag{11}\\
LPPO(θ)=−Ex∼p,y∼π(⋅∣x)∣y∣1t=1∑∣y∣{min[rt(θ)At,clip(rt(θ),1−ε,1+ε)At]−β1KL[πθold(at∣st),πθref(at∣st)]−β2at∑πθ(at∣st)lnπθ(at∣st)}(11)
4.1 熵正则的计算
令
l
=
[
l
1
,
.
.
.
,
l
∣
v
∣
]
l=[l_1,...,l_{|v|}]
l=[l1,...,l∣v∣]表示未经过softmax
的logit。那么词表中第
i
i
i个token的概率为
p
i
=
e
l
i
∑
j
=
1
∣
v
∣
e
l
j
p_i=\frac{e^{l_i}}{\sum_{j=1}^{|v|}e^{l_j}} \\
pi=∑j=1∣v∣eljeli
两边取对数,有
log
p
i
=
l
i
−
log
∑
j
=
1
∣
v
∣
e
l
j
\log p_i = l_i - \log\sum_{j=1}^{|v|}e^{l_j} \\
logpi=li−logj=1∑∣v∣elj
那么信息熵为
−
∑
i
=
1
∣
v
∣
p
i
log
p
i
=
−
∑
i
=
1
∣
v
∣
p
i
(
l
i
−
log
∑
j
=
1
∣
v
∣
e
l
j
)
=
log
∑
j
=
1
∣
v
∣
e
l
j
−
∑
i
=
1
∣
v
∣
p
i
l
i
(12)
-\sum_{i=1}^{|v|} p_i\log p_i=-\sum_{i=1}^{|v|}p_i(l_i-\log\sum_{j=1}^{|v|}e^{l_j})=\log\sum_{j=1}^{|v|}e^{l_j}-\sum_{i=1}^{|v|}p_i l_i \tag{12}\\
−i=1∑∣v∣pilogpi=−i=1∑∣v∣pi(li−logj=1∑∣v∣elj)=logj=1∑∣v∣elj−i=1∑∣v∣pili(12)
verl中的熵正则计算代码为
# verl/utils/torch_functional.py
def entropy_from_logits(logits: torch.Tensor):
"""logits: [bs, response_length, vocab_size]"""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
return entropy
5. Critic模型loss的计算
Critic
模型的作用是预测状态价值函数
V
V
V的取值。如果有明确的标签,则使用MSE计算回归loss即可。
但是实际计算中并没有这个标签。因此,在Actor-Critic范式中使用包含了真实奖励的折扣回报作为标签。令
V
t
V_t
Vt是Critic
模型直接针对第
t
t
t个token预测的价值,
R
t
R_t
Rt则是通过优势计算出包含真实奖励的折扣回报,那么损失函数为
L
PPO
Critic
(
θ
)
=
E
x
∼
p
,
y
∼
π
(
⋅
∣
x
)
1
∣
y
∣
∑
t
=
1
∣
y
∣
1
2
max
(
(
V
t
−
R
t
)
2
,
(
clip
(
V
t
)
−
R
t
)
2
)
(13)
L_{\text{PPO}}^{\text{Critic}}(\theta) = \text{E}_{x\sim p,y\sim\pi(\cdot|x)}\frac{1}{|y|}\sum_{t=1}^{|y|}\frac{1}{2}\max\Big( (V_t-R_t)^2,(\text{clip}(V_t)-R_t)^2 \Big) \tag{13}\\
LPPOCritic(θ)=Ex∼p,y∼π(⋅∣x)∣y∣1t=1∑∣y∣21max((Vt−Rt)2,(clip(Vt)−Rt)2)(13)
这里也对
V
t
V_t
Vt进行了clip。下面是verl中的实现
# verl/trainer/ppo/core_algos.py
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
"""
vpreds:(`batch_size`, `response_length`)
values:(`batch_size`, `response_length`)
returns:(`batch_size`, `response_length`)
"""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns)**2
vf_losses2 = (vpredclipped - returns)**2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac
6. verl中PPO的整体流程
通过代码来看PPO的计算流程并不直观,为了更直观的理解绘制了下图。绿色的节点代表模型,红色的菱形代码处理逻辑,虚线的椭圆形代表输出。
二、GRPO
论文:DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
1. 原理
PPO中需要额外训练一个值函数(Critic模型),这会带来很大的显存需求和计算负担。GRPO的目标就是去掉Critic模型。回忆一下,PPO中Critic模型主要是用于估计优势函数中的 V V V,从而进一步计算优势 A t = Q ( s t , a t ) − V ( s t ) A_t=Q(s_t,a_t)-V(s_t) At=Q(st,at)−V(st)。
在优势计算中之所以减去 V ( s t ) V(s_t) V(st),是因为直接利用真实奖励蒙特卡洛估计 Q ( s t , a t ) Q(s_t,a_t) Q(st,at)会导致高方差,而 V ( s t ) V(s_t) V(st)作为 Q ( s t , a t ) Q(s_t,a_t) Q(st,at)的期望,充当了baseline。那么,除了减去baseline实现降低方差外,也可以通过多次采样求均值的方式来减少方差。因此GRPO相较于PPO的主要改进为:(1) 单个样本多次采取求均值减少方差;(2) 优势计算中去掉了对值函数的依赖。
具体来说,损失函数更新为
L
GRPO
(
θ
)
=
−
E
x
∼
p
,
{
y
}
i
=
1
G
∼
π
(
⋅
∣
x
)
1
G
∑
i
=
1
G
1
∣
y
i
∣
∑
t
=
1
∣
y
i
∣
{
min
[
r
i
,
t
(
θ
)
A
i
,
t
,
clip
(
r
i
,
t
(
θ
)
,
1
−
ε
,
1
+
ε
)
A
i
,
t
]
−
β
KL
[
π
θ
old
(
a
i
,
t
∣
s
i
,
t
)
,
π
θ
ref
(
a
i
,
t
∣
s
i
,
t
)
]
}
(14)
L_{\text{GRPO}}(\theta)=-\text{E}_{x\sim p,\{y\}_{i=1}^G\sim\pi(\cdot|x)}\frac{1}{G}\sum_{i=1}^G\frac{1}{|y_i|}\sum_{t=1}^{|y_i|}\Big\{\min[r_{i,t}(\theta)A_{i,t},\text{clip}(r_{i,t}(\theta),1-\varepsilon,1+\varepsilon)A_{i,t}]\\ -\beta\text{KL}[\pi_{\theta_{\text{old}}}(a_{i,t}|s_{i,t}),\pi_{\theta_{\text{ref}}}(a_{i,t}|s_{i,t})]\Big\} \tag{14}\\
LGRPO(θ)=−Ex∼p,{y}i=1G∼π(⋅∣x)G1i=1∑G∣yi∣1t=1∑∣yi∣{min[ri,t(θ)Ai,t,clip(ri,t(θ),1−ε,1+ε)Ai,t]−βKL[πθold(ai,t∣si,t),πθref(ai,t∣si,t)]}(14)
其中重要性采样为
r
i
,
t
(
θ
)
=
π
θ
(
a
i
,
t
∣
s
i
,
t
)
π
θ
old
(
a
i
,
t
∣
s
i
,
t
)
r_{i,t}(\theta)=\frac{\pi_{\theta}(a_{i,t}|s_{i,t})}{\pi_{\theta_{\text{old}}}(a_{i,t}|s_{i,t})}
ri,t(θ)=πθold(ai,t∣si,t)πθ(ai,t∣si,t)。在GRPO的论文中,KL散度指定使用
k
3
k3
k3估计器。
此外,由于单个样本
x
x
x会进行G次采样,则会得到一组输出
{
o
1
,
o
2
,
…
,
o
G
}
\{o_1,o_2,\dots,o_G\}
{o1,o2,…,oG},对应的得到一组奖励
r
=
{
r
1
,
r
2
,
…
,
r
G
}
r=\{r_1,r_2,\dots,r_G\}
r={r1,r2,…,rG}。GRPO将normalize后的奖励作为优势函数的取值
A
i
,
t
=
r
~
i
=
r
i
−
mean
(
r
)
std
(
r
)
(15)
A_{i,t}=\tilde{r}_i=\frac{r_i-\text{mean}(r)}{\text{std}(r)} \tag{15}\\
Ai,t=r~i=std(r)ri−mean(r)(15)
2. 实现
优势计算的代码如下:
# verl/trainer/ppo/core_algos.py
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
"""
token_level_rewards: (bs, response_length)
eos_mask: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
# id2score的value是一个列表,其包含了G个奖励值,key则是为每个输入x生成的id
id2score = defaultdict(list)
# id2mean和id2std分别存储每组的均值和标准差
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
breakpoint()
for i in range(bsz):
# normalize
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
除了优势计算需要重写外,整个loss的计算直接复用PPO的即可,只要指定每个输入
x
x
x的重复采样次数actor_rollout_ref.rollout.n
即可。
三、REINFORCE++
论文:REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models
1. 原理
REINFORCE++的动机仍然是去掉PPO中的Critic模型。相比于GRPO采用组内平均的方式减少方差,REINFORCE++则更加直接,其在保留了PPO绝大数trick的情况下直接去掉了baseline。
简单来说,去掉优势函数中的baseline
V
(
s
t
)
V(s_t)
V(st),那么优势为
A
t
=
Q
(
s
t
,
a
t
)
A_t=Q(s_t,a_t)
At=Q(st,at)。采用蒙特卡洛估计来计算
Q
(
s
t
,
a
t
)
Q(s_t,a_t)
Q(st,at),即使用折扣回报
G
t
=
∑
k
=
t
+
1
T
γ
k
−
t
r
k
G_t=\sum_{k=t+1}^T\gamma^{k-t}r_k
Gt=∑k=t+1Tγk−trk作为优势
A
t
A_t
At的估计值。此外,为了保持稳定的梯度,再对计算出来的优势进行normalize
A
~
t
=
A
t
−
μ
A
σ
A
(16)
\tilde{A}_t=\frac{A_t-\mu_A}{\sigma_A} \tag{16}\\
A~t=σAAt−μA(16)
其中,
μ
A
\mu_A
μA和
σ
A
\sigma_A
σA是整个batch优势的均值和标准差。
2. 实现
verl中优势和回报的计算
# verl/trainer/ppo/core_algos.py
def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
gamma: torch.Tensor):
"""
token_level_rewards:(bs, response_length)
eos_mask:(bs, response_length)
"""
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
return advantages, returns
下面是normalize的代码,其中mean和var都是标量值
def masked_whiten(values, mask, shift_mean=True):
# mean和var都是标量值
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
四、RLOO
论文:Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs
1. 原理
RLOO同样是要去掉Critic模型,但采取了不同于GRPO的baseline估计方式。类似于GRPO,通常是单个样本进行多次采样。假设对样本
x
x
x进行G次采样,则会得到一组输出
{
o
1
,
o
2
,
…
,
o
G
}
\{o_1,o_2,\dots,o_G\}
{o1,o2,…,oG},对应的得到一组奖励
r
=
{
r
1
,
r
2
,
…
,
r
G
}
r=\{r_1,r_2,\dots,r_G\}
r={r1,r2,…,rG}。那么
[
x
;
o
i
]
[x;o_i]
[x;oi]的优势函数为
A
i
=
1
G
∑
i
=
1
G
[
r
i
−
1
G
−
1
∑
j
≠
i
r
j
]
(17)
A_i=\frac{1}{G}\sum_{i=1}^G\Big[r_i-\frac{1}{G-1}\sum_{j\neq i}r_j\Big] \tag{17}\\
Ai=G1i=1∑G[ri−G−11j=i∑rj](17)
这样通过采样一组输出后,将除自身以外的奖励均值作为baseline。
2. 实现
# verl/trainer/ppo/core_algos.py
def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
"""
token_level_rewards:(bs, response_length)
eos_mask:(bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
# (bs,),聚合token_level_rewards形成每个轨迹的奖励
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {} # 组内均值
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
breakpoint()
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (response_num - 1)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
乍看起来,这里的实现与论文中并不一致。这里简单推导一下实现和论文的关系
A
i
=
G
G
−
1
⋅
r
i
−
G
G
−
1
⋅
1
G
∑
j
=
1
G
r
j
=
G
G
−
1
⋅
r
i
−
1
G
−
1
∑
j
=
1
G
r
j
=
G
−
1
G
−
1
⋅
r
i
+
1
G
−
1
⋅
r
i
−
1
G
−
1
(
r
i
+
∑
j
≠
i
r
j
)
=
r
i
−
1
G
−
1
∑
j
≠
i
r
j
\begin{align} A_i &= \frac{G}{G-1}\cdot r_i-\frac{G}{G-1}\cdot\frac{1}{G}\sum_{j=1}^G r_j \\ &= \frac{G}{G-1}\cdot r_i-\frac{1}{G-1}\sum_{j=1}^G r_j \\ &= \frac{G-1}{G-1}\cdot r_i + \frac{1}{G-1}\cdot r_i - \frac{1}{G-1}\Big(r_i +\sum_{j\neq i}r_j\Big) \\ &=r_i-\frac{1}{G-1}\sum_{j\neq i}r_j \end{align} \\
Ai=G−1G⋅ri−G−1G⋅G1j=1∑Grj=G−1G⋅ri−G−11j=1∑Grj=G−1G−1⋅ri+G−11⋅ri−G−11(ri+j=i∑rj)=ri−G−11j=i∑rj
参考资料
http://joschu.net/blog/kl-approx.html
https://arxiv.org/html/2501.03262v1
https://arxiv.org/pdf/2402.03300
https://arxiv.org/pdf/2402.14740
https://zhuanlan.zhihu.com/p/25208314999
https://zhuanlan.zhihu.com/p/675309680
https://zhuanlan.zhihu.com/p/675348061