GSPO如何消除高方差且不依赖routing replay
GSPO算法的创新
GSPO算法的改进在于将GRPO算法的token级重要性采样比值替换为序列重要性采样, 设计动机是修复GRPO算法的一个bug, reward是序列级别的(生成完整序列后才计算得到), 而重要性采样是token级别的, 天然不对等导致很多不稳定性问题. 具体优化项为:
s i ( θ ) = ( π θ ( y i ∣ x ) π θ old ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log π θ ( y i , t ∣ x , y i , < t ) π θ old ( y i , t ∣ x , y i , < t ) ) s_i(\theta)=\left(\frac{\pi_\theta\left(y_i \mid x\right)}{\pi_{\theta_{\text {old }}}\left(y_i \mid x\right)}\right)^{\frac{1}{\left|y_i\right|}}=\exp \left(\frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_\theta\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) si(θ)=(πθold (yi∣x)πθ(yi∣x))∣yi∣1=exp ∣yi∣1t=1∑∣yi∣logπθold (yi,t∣x,yi,<t)πθ(yi,t∣x,yi,<t)
这个改动有一箭双雕的作用:
- 消除GRPO算法存在高方差和不稳定性的问题
- 消除MoE架构下GRPO算法对routing replay的依赖
GSPO算法的优势
消除GRPO算法token级别重要性采样导致的高方差
GSPO效果上能够实现消除高方差, 可以从几何平均的角度来理解. 如果用算数平均, 无论是先加权(GRPO)还是先平均(GSPO)数学意义上是等价的. 算数平均与几何平均的差异可以参考这篇博客, 里面提到几何平均适用于相对量的平均, 方差就是关于算数平均的相对值, 几何平均后能够减小方差.
不考虑clip的情况下, GRPO的目标函数为:
J G R P O ( θ ) = E [ 1 G ∑ i = 1 G 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ ( w i , t ( θ ) A ^ i ) ] w i , t ( θ ) = π θ ( y i , t ∣ x , y i < t ) π θ o l d ( y i , t ∣ x , y i < t ) \mathcal{J}_{\mathrm{GRPO}}(\theta)=\mathbb{E}\left[\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|}\left(w_{i, t}(\theta) \widehat{A}_i\right)\right] \\ w_{i, t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i<t})}{\pi_{\theta_{old}}(y_{i,t} \mid x, y_{i<t})} JGRPO(θ)=E G1i=1∑G∣yi∣1t=1∑∣yi∣(wi,t(θ)A i) wi,t(θ)=πθold(yi,t∣x,yi<t)πθ(yi,t∣x,yi<t)
GSPO目标函数为:
$$
\mathcal{J}{\mathrm{GSPO}}(\theta)=\mathbb{E}\left[\frac{1}{G} \sum{i=1}^G s_{i}(\theta) \widehat{A}i\right] \
s_i(\theta) = \frac{1}{\left|y_i\right|}\sum{t=1}^{\left|y_i\right|}\frac{\pi_\theta(y_{i,t} \mid x, y_{i<t})}{\pi_{\theta_{old}}(y_{i,t} \mid x, y_{i<t})}\text{算数平均} \
s_i(\theta)=\left(\frac{\pi_\theta\left(y_i \mid x\right)}{\pi_{\theta_{\text {old }}}\left(y_i \mid x\right)}\right)^{\frac{1}{\left|y_i\right|}}=\exp \left(\frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_\theta\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right)\text{几何平均}
$$
观察上面两个目标函数, 若GSPO使用算数平均, 与GRPO在数学意义上等价. 在考虑clip的情况序列级别的重要性比值设计也起到了作用, GRPO中相当于clip对token级别存在高方差的重要比值进行操作, 引入不稳定性, GSPO中clip对序列级别的低方差重要性比值进行操作.
从几何平均和算数平均角度只是从一个具象的可解释的角度来阐述GSPO如何消除高方差这一抽象概念, 并不是说两个都正确, 对于序列重要性采样几何平均才是正确形式. 从正确性角度来说, 首先GRPO算法中token级别的重要性采样和序列级别的reward之间存在不对等的错误; 其次序列级别的重要性采样需要对token的概率进行连乘, 进行归一化且化简后数学形式上就是几何平均, 若用算数平均则存在逻辑错误.
消除MOE架构下GRPO算法对routing replay的依赖
GSPO消除对routing replay的依赖可以从边缘概率的角度来理解. 首先, MoE架构下GRPO方法token级别的重要性采样可以理解为以下联合概率分布:
π θ ( y i , t ∣ x , y i , < t ) = p ( r t ∣ x ) π θ ( y i , t ∣ x , r t ) \pi_{\theta}(y_{i, t} \mid x, y_{i, <t}) = p(r_t \mid x)\pi_\theta(y_{i,t} \mid x, r_t) πθ(yi,t∣x,yi,<t)=p(rt∣x)πθ(yi,t∣x,rt)
p ( r t ∣ x ) p(r_t \mid x) p(rt∣x)表示上下文条件 x x x下选择专家 r t r_t rt的概率, 即专家路由router的分布, π θ \pi_{\theta} πθ表示生成token的概率分布. 边缘概率是指在联合概率分布中,通过对其他变量进行求和或积分,得到的某个变量的概率分布.
在MoE架构的模型中, 按照传统token级别的重要性采样(如GRPO方法), 无法对router概率分布进行完全表示, 无法对联合概率分布解耦, 工程上只能用routing replay的形式严格保证采样和训练时router的概率分布一致来实现对生成策略 π θ ( y i , t ∣ x ) \pi_\theta(y_{i,t} \mid x) πθ(yi,t∣x)的优化.
而GSPO方法中, 按照序列级别进行重要性采样, 对于序列 y y y来说生成概率为 π θ ( y ∣ x ) \pi_\theta(y\mid x) πθ(y∣x). 可以表示为对router变量的积分形式:
π θ ( y ∣ x ) = ∑ r π θ ( y , r ∣ x ) \pi_\theta(y\mid x) = \sum_r \pi_\theta(y, r \mid x) πθ(y∣x)=r∑πθ(y,r∣x)
即序列概率天然隐式包含了对所有可能路由 r r r的边缘化, 即已经对router变量积分, 那么序列概率分布就是仅关于token生成策略 π θ ( y i , t ∣ x ) \pi_\theta(y_{i,t}\mid x) πθ(yi,t∣x)的边缘概率, 因此不再需要routing replay.