当前位置: 首页 > news >正文

GRPO与GSPO算法训练对比

项目地址: GRPO vs GSPO
GRPO原文: Group Relative Policy Optimization
GSPO原文: Group Sequence Policy Optimization

  • 数据集: GSM8K
  • 参考模型: qwen2.5-1.5B-Instruct
  • 目标模型: qwen2.5-1.5B-Instruct
  • 硬件配置: 3 × AutoDL vGPU-32G (GPU0/1用于训练, GPU2用于采样)
  • 训练步数: 200 steps (60min)

grpo_vs_gspo

准确率评估包含答案和格式两部分:

  • GSPO算法在50个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右
  • GRPO算法在120个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右

从结果来看GSPO训练速度明显优于GRPO, 消耗更少的时间达到稳定状态. 从模型特性来解释, GSPO模型训练时方差更小, 在矫正输出分布时有更强的确定性能够快速调整, 宏观上体现为更快得收敛至稳定值. 训练至200步后两种方法训练的结果基本接近, 应该是达到模型极限.

GRPO算法

目标函数:
JGRPO(θ)=E[1G∑i=1G1∣yi∣∑t=1∣yi∣min⁡(wi,t(θ)A^i,t,clip(wi,t(θ),1−ϵ,1+ϵ)A^i,t)]J_{GRPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \min \left(w_{i, t}(\theta) \hat{A}_{i, t}, \text{clip}\left(w_{i, t}(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_{i, t}\right)\right] JGRPO(θ)=EG1i=1Gyi1t=1yimin(wi,t(θ)A^i,t,clip(wi,t(θ),1ϵ,1+ϵ)A^i,t)
wi,t(θ)=πθ(yi,t∣x,yi<t)πθold(yi,t∣x,yi<t)w_{i, t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i<t})}{\pi_{\theta_{old}}(y_{i,t} \mid x, y_{i<t})} wi,t(θ)=πθold(yi,tx,yi<t)πθ(yi,tx,yi<t)
A^i,t=A^i=r(x,yi)−mean({r(x,yi)}i=1G)std({r(x,yi)}i=1G)\hat{A}_{i, t}=\hat{A}_i=\frac{r(x, y_i)-\text{mean}(\{r(x, y_i)\}_{i=1}^G)}{\text{std}(\{r(x, y_i)\}_{i=1}^G)} A^i,t=A^i=std({r(x,yi)}i=1G)r(x,yi)mean({r(x,yi)}i=1G)

其中wi,tw_{i,t}wi,t表示token级别的重要性采样, A^i\hat{A}_{i}A^i表示序列的组内回报值:

代码实现:

ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # 参考策略概率分布
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:] # 旧策略概率分布
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:] # 新策略概率分布
attention_mask_       = attention_mask[:, prefix_len:]importance_ratio = torch.exp(new_policy_log_probs_ - old_policy_log_probs_) # 重要性采样
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # 相似度裁剪
importance_term = importance_ratio * advantages
clip_term = cliped_ratio * advantageskl_term = torch.exp(ref_policy_log_probs_ - new_policy_log_probs_) - (ref_policy_log_probs_ - new_policy_log_probs_) - 1 # kl散度objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term # 目标函数
per_token_loss = -objective_function # loss函数loss = ((per_token_loss * attention_mask_).sum(dim=1) / attention_mask_.sum(dim=1)).mean() # batch的均值作为最终loss(只统计有效token的loss)

GSPO算法

目标函数:
JGSPO(θ)=E[1G∑i=1Gmin⁡(si(θ)A^i,clip(si(θ),1−ϵ,1+ϵ)A^i)]J_{GSPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \min \left(s_i(\theta) \hat{A}_i, \text{clip}\left(s_i(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_i\right)\right] JGSPO(θ)=E[G1i=1Gmin(si(θ)A^i,clip(si(θ),1ϵ,1+ϵ)A^i)]
si(θ)=(πθ(yi∣x)πθold (yi∣x))1∣yi∣=exp⁡(1∣yi∣∑t=1∣yi∣log⁡πθ(yi,t∣x,yi,<t)πθold (yi,t∣x,yi,<t))s_i(\theta)=\left(\frac{\pi_\theta\left(y_i \mid x\right)}{\pi_{\theta_{\text {old }}}\left(y_i \mid x\right)}\right)^{\frac{1}{\left|y_i\right|}}=\exp \left(\frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_\theta\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) si(θ)=(πθold (yix)πθ(yix))yi1=expyi1t=1yilogπθold (yi,tx,yi,<t)πθ(yi,tx,yi,<t)

其中si(θ)s_i(\theta)si(θ)表示序列重要性采样, 与序列组内回报A^i\hat{A}_iA^i颗粒度是对齐的.

代码实现:

batch_size = ref_policy_log_probs.shape[0]# 取生成部分的概率分布
ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # token_0裁剪了, 因此需要裁剪的长度为prefix_len-1
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:]
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:]
attention_mask_       = attention_mask[:, prefix_len:]         # attention_mask维度中token_0的位置没裁剪, 因此需要裁剪的长度为prefix_len# 计算有效序列, 遮掩pad_token
valid_seq_len = attention_mask_.sum(dim=1)
new_old_log_probs_ = (new_policy_log_probs_ - old_policy_log_probs_) * attention_mask_
ref_new_log_probs_ = (ref_policy_log_probs_ - new_policy_log_probs_) * attention_mask_# 序列级别的重要性采样
importance_ratio = torch.exp(new_old_log_probs_.sum(dim=1) / valid_seq_len).view(batch_size, 1) # batch_size * 1
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # batch_size * 1
importance_term = importance_ratio * advantages # batch_size * 1
clip_term = cliped_ratio * advantages # batch_size * 1kl_term = torch.exp(ref_new_log_probs_.sum(dim=1) / valid_seq_len) - (ref_new_log_probs_.sum(dim=1) / valid_seq_len) - 1
kl_term = kl_term.view(batch_size, 1)objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term
sequence_loss = -objective_function# 批次平均损失作为总损失
loss = sequence_loss.mean()
http://www.dtcms.com/a/475248.html

相关文章:

  • 如何制作网站板块php 企业网站模板
  • 佛山网站制作好处wordpress 扣积分
  • linux重定向中 >file 2>1,>>file 2>1 , >>file是什莫意思
  • 网站引导插件做网站最好的软件是
  • C++ 泛型
  • 网站网站建设公司企业为什么要增资
  • 第9章:两条道路的风景:技术与管理的真实世界(3)
  • Python 基础教程 | 菜鸟教程
  • 建设网站需求劳务公司简介模板
  • 解决 Vue 3 + TypeScript 中 v-for 循环类型推断问题
  • 外贸网站建站注意事项及价格宣传片拍摄脚本范本
  • Linux碎碎念:网络抓包利器:tcpdump 使用与分析入门
  • 十堰网站建设是什么塔罗牌手机网站制作
  • 北京网站制作费用wampserver安装wordpress
  • c可以做网站么公司网站域名无法解析
  • 做php网站教程视频住建部网站统计城乡建设统计信息系统登录
  • 风铃网站具体是做那方面的网站后台演示地址
  • 网站 建设 内容网站后台登录界面下载
  • 园林效果图网站兰州网站排名优化服务
  • Starting again-03
  • 探秘编译器背后的语言密码:从底层实现到技术演进的全景图
  • iis 里没有网站吗深圳的网站建设公司三把火
  • 肇庆企业建站程序evernote wordpress
  • JavaWeb学习-web开发什么是web开发
  • 专业开发网站企业net网站开发net网站开发
  • 最专业的企业营销型网站建设5分钟建站wordpress
  • JavaEE--Spring MVC
  • 建设网站简单的需要多少天网站开发技术要学什么软件
  • XCP协议在以太网上实现的配置
  • 榆林高端网站建设如何设计苏州做网站的公司有哪些