GRPO与GSPO算法训练对比
项目地址: GRPO vs GSPO
GRPO原文: Group Relative Policy Optimization
GSPO原文: Group Sequence Policy Optimization
- 数据集: GSM8K
- 参考模型: qwen2.5-1.5B-Instruct
- 目标模型: qwen2.5-1.5B-Instruct
- 硬件配置: 3 × AutoDL vGPU-32G (GPU0/1用于训练, GPU2用于采样)
- 训练步数: 200 steps (60min)
准确率评估包含答案和格式两部分:
- GSPO算法在50个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右
- GRPO算法在120个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右
从结果来看GSPO训练速度明显优于GRPO, 消耗更少的时间达到稳定状态. 从模型特性来解释, GSPO模型训练时方差更小, 在矫正输出分布时有更强的确定性能够快速调整, 宏观上体现为更快得收敛至稳定值. 训练至200步后两种方法训练的结果基本接近, 应该是达到模型极限.
GRPO算法
目标函数:
JGRPO(θ)=E[1G∑i=1G1∣yi∣∑t=1∣yi∣min(wi,t(θ)A^i,t,clip(wi,t(θ),1−ϵ,1+ϵ)A^i,t)]J_{GRPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \min \left(w_{i, t}(\theta) \hat{A}_{i, t}, \text{clip}\left(w_{i, t}(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_{i, t}\right)\right] JGRPO(θ)=EG1i=1∑G∣yi∣1t=1∑∣yi∣min(wi,t(θ)A^i,t,clip(wi,t(θ),1−ϵ,1+ϵ)A^i,t)
wi,t(θ)=πθ(yi,t∣x,yi<t)πθold(yi,t∣x,yi<t)w_{i, t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i<t})}{\pi_{\theta_{old}}(y_{i,t} \mid x, y_{i<t})} wi,t(θ)=πθold(yi,t∣x,yi<t)πθ(yi,t∣x,yi<t)
A^i,t=A^i=r(x,yi)−mean({r(x,yi)}i=1G)std({r(x,yi)}i=1G)\hat{A}_{i, t}=\hat{A}_i=\frac{r(x, y_i)-\text{mean}(\{r(x, y_i)\}_{i=1}^G)}{\text{std}(\{r(x, y_i)\}_{i=1}^G)} A^i,t=A^i=std({r(x,yi)}i=1G)r(x,yi)−mean({r(x,yi)}i=1G)
其中wi,tw_{i,t}wi,t表示token级别的重要性采样, A^i\hat{A}_{i}A^i表示序列的组内回报值:
代码实现:
ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # 参考策略概率分布
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:] # 旧策略概率分布
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:] # 新策略概率分布
attention_mask_ = attention_mask[:, prefix_len:]importance_ratio = torch.exp(new_policy_log_probs_ - old_policy_log_probs_) # 重要性采样
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # 相似度裁剪
importance_term = importance_ratio * advantages
clip_term = cliped_ratio * advantageskl_term = torch.exp(ref_policy_log_probs_ - new_policy_log_probs_) - (ref_policy_log_probs_ - new_policy_log_probs_) - 1 # kl散度objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term # 目标函数
per_token_loss = -objective_function # loss函数loss = ((per_token_loss * attention_mask_).sum(dim=1) / attention_mask_.sum(dim=1)).mean() # batch的均值作为最终loss(只统计有效token的loss)
GSPO算法
目标函数:
JGSPO(θ)=E[1G∑i=1Gmin(si(θ)A^i,clip(si(θ),1−ϵ,1+ϵ)A^i)]J_{GSPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \min \left(s_i(\theta) \hat{A}_i, \text{clip}\left(s_i(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_i\right)\right] JGSPO(θ)=E[G1i=1∑Gmin(si(θ)A^i,clip(si(θ),1−ϵ,1+ϵ)A^i)]
si(θ)=(πθ(yi∣x)πθold (yi∣x))1∣yi∣=exp(1∣yi∣∑t=1∣yi∣logπθ(yi,t∣x,yi,<t)πθold (yi,t∣x,yi,<t))s_i(\theta)=\left(\frac{\pi_\theta\left(y_i \mid x\right)}{\pi_{\theta_{\text {old }}}\left(y_i \mid x\right)}\right)^{\frac{1}{\left|y_i\right|}}=\exp \left(\frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_\theta\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) si(θ)=(πθold (yi∣x)πθ(yi∣x))∣yi∣1=exp∣yi∣1t=1∑∣yi∣logπθold (yi,t∣x,yi,<t)πθ(yi,t∣x,yi,<t)
其中si(θ)s_i(\theta)si(θ)表示序列重要性采样, 与序列组内回报A^i\hat{A}_iA^i颗粒度是对齐的.
代码实现:
batch_size = ref_policy_log_probs.shape[0]# 取生成部分的概率分布
ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # token_0裁剪了, 因此需要裁剪的长度为prefix_len-1
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:]
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:]
attention_mask_ = attention_mask[:, prefix_len:] # attention_mask维度中token_0的位置没裁剪, 因此需要裁剪的长度为prefix_len# 计算有效序列, 遮掩pad_token
valid_seq_len = attention_mask_.sum(dim=1)
new_old_log_probs_ = (new_policy_log_probs_ - old_policy_log_probs_) * attention_mask_
ref_new_log_probs_ = (ref_policy_log_probs_ - new_policy_log_probs_) * attention_mask_# 序列级别的重要性采样
importance_ratio = torch.exp(new_old_log_probs_.sum(dim=1) / valid_seq_len).view(batch_size, 1) # batch_size * 1
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # batch_size * 1
importance_term = importance_ratio * advantages # batch_size * 1
clip_term = cliped_ratio * advantages # batch_size * 1kl_term = torch.exp(ref_new_log_probs_.sum(dim=1) / valid_seq_len) - (ref_new_log_probs_.sum(dim=1) / valid_seq_len) - 1
kl_term = kl_term.view(batch_size, 1)objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term
sequence_loss = -objective_function# 批次平均损失作为总损失
loss = sequence_loss.mean()