当前位置: 首页 > news >正文

trl GRPO源码分析:如何处理多个reward function?

在GRPOTrainer源码中,多个reward function的处理主要通过以下步骤实现:

1. 初始化reward functions

  • reward_funcs参数可以是单个reward function或一个列表。在初始化时,它被统一转换为列表:

    if not isinstance(reward_funcs, list):reward_funcs = [reward_funcs]
    
  • 每个reward function可以是以下类型:

    • 字符串:被视为预训练模型的ID,使用AutoModelForSequenceClassification.from_pretrained加载,设置num_labels=1
    • PreTrainedModel:直接使用序列分类模型。
    • 自定义函数:一个可调用对象,接受prompts和completions等参数,返回reward列表。自定义函数可以返回None表示不适用于某些样本。
  • 同时,记录每个reward function的名称用于日志:

    for i, reward_func in enumerate(reward_funcs):if isinstance(reward_func, str):reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(...)if isinstance(reward_funcs[i], nn.Module):self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])else:self.reward_func_names.append(reward_funcs[i].__name__)
    

2. 设置reward weights

  • 如果提供了args.reward_weights,则使用自定义权重,否则默认所有reward function的权重为1:
    if args.reward_weights is not None:if len(args.reward_weights) != len(reward_funcs):raise ValueError("Number of reward weights must match number of reward functions")self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
    else:self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
    
  • 权重用于后续的加权求和。

3. 处理reward processing classes

  • 每个reward function可以有一个对应的processing class(如tokenizer),用于预处理输入。如果未提供,则自动加载:
    for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):if isinstance(reward_func, PreTrainedModel):if reward_processing_class is None:reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)# 设置pad tokenreward_processing_class.pad_token = reward_processing_class.eos_tokenreward_func.config.pad_token_id = reward_processing_class.pad_token_idreward_processing_classes[i] = reward_processing_class
    
  • 对于自定义reward function,对应的processing class被忽略。

4. 计算rewards

  • _calculate_rewards方法中,遍历所有reward function,计算每个function的reward:
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
    for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(...):if isinstance(reward_func, nn.Module):# 对于模型reward function:将prompts和completions组合,tokenize,然后通过模型计算rewardtexts = [p + c for p, c in zip(prompts, completions)]  # 或应用chat templatereward_inputs = reward_processing_class(text=texts, padding=True, ...)with torch.inference_mode():rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]else:# 对于自定义reward function:直接调用函数,传入prompts、completions和其他参数output_reward_func = reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs)# 处理None值:转换为NaNoutput_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
    
  • 这里,reward_kwargs包括数据集中除"prompt"、“completion”、"completion_ids"外的所有列,以及trainer_state(训练状态)。

5. 加权求和rewards

  • 所有reward function的输出被组合成一个张量rewards_per_func(形状为[batch_size, num_reward_funcs]),然后与权重相乘并求和:
    rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
    
  • 使用nansum是为了处理自定义reward function返回None(转换为NaN)的情况。如果所有reward function对某个样本都返回NaN,则最终reward为NaN,并会发出警告。

6. 处理NaN情况

  • 如果所有reward function对某个样本都返回NaN,会记录警告:
    if torch.isnan(rewards_per_func).all(dim=1).any():nan_row_idx = ...  # 找到第一个全NaN的行logger.warning("All reward functions returned None for the following kwargs: ...")
    

7. 日志记录

  • 每个reward function的mean和std被记录到metrics中,便于监控:
    for i, reward_func_name in enumerate(self.reward_func_names):mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)std_func_rewards = nanstd(rewards_per_func[:, i]).item()self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
    
  • 这些日志在训练和评估时都会更新。

8. 优势计算和归一化

  • 加权后的rewards用于计算优势(advantages)。rewards首先按prompt分组(每组包含num_generations个completion),计算组内mean,然后归一化:
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    advantages = rewards - mean_grouped_rewards
    
  • 优势可以按组或批量缩放(通过scale_rewards参数控制)。

总结

多个reward function的处理是灵活的:支持混合类型(模型和自定义函数),允许自定义权重,并处理不适用情况(返回None)。最终reward是加权和,同时每个function的贡献被单独记录用于监控。这种设计适用于多任务学习,其中不同reward函数可能应用于不同类型样本。


文章转载自:

http://I11TlgS2.bLznh.cn
http://M4nBJ08v.bLznh.cn
http://h0c2Fxry.bLznh.cn
http://NE9qjKUR.bLznh.cn
http://kY9sRuGY.bLznh.cn
http://ZX1QWaOX.bLznh.cn
http://eHL1wA6c.bLznh.cn
http://lIEhXRqJ.bLznh.cn
http://6JmjCn5Y.bLznh.cn
http://jVXxJV6M.bLznh.cn
http://IewhkBm8.bLznh.cn
http://64gYtozl.bLznh.cn
http://qamcrIri.bLznh.cn
http://0AN4TUIe.bLznh.cn
http://FfSjSt9x.bLznh.cn
http://OlJ6hpRV.bLznh.cn
http://WIXtonNG.bLznh.cn
http://zD5ScG4Z.bLznh.cn
http://Ti657mll.bLznh.cn
http://qdrwmWnQ.bLznh.cn
http://cDBH1fWM.bLznh.cn
http://QwecvLC7.bLznh.cn
http://1Wo36inL.bLznh.cn
http://fQbjhswQ.bLznh.cn
http://GjiyIiUr.bLznh.cn
http://kRd0doxJ.bLznh.cn
http://YyL2pQS7.bLznh.cn
http://igjf2QDb.bLznh.cn
http://5k3CjnAU.bLznh.cn
http://7lUz221H.bLznh.cn
http://www.dtcms.com/a/372105.html

相关文章:

  • 临床研究三千问——临床研究体系的3个维度(8)
  • TypeORM入门教程:@JoinColumn和@OneToOne的关系
  • html列表标签之无序列表
  • [1]-01-创建空工程
  • 【模型训练篇】VeRL核心思想 - 论文HybridFlow
  • pycharm设置编辑区字体大小
  • 鸿蒙NEXT跨设备数据同步实战:分布式应用开发指南
  • C++ 中栈 (Stack) 详解和常见面试示例汇总实现
  • [光学原理与应用-461]:波动光学 - 波片实现偏振态的转换或调整
  • 苍穹外卖Day12 | Apache POI、导出Excel报表、HttpServletResponse、工作台
  • 《Go小技巧易错点100例》第三十八篇
  • Conda 包管理器与环境管理使用指南
  • 笔记本、平板如何成为电脑拓展屏?向日葵16成为副屏功能一键实现
  • OpenHarmony 显示能效管理组件:掌控屏幕亮灭与亮度的核心利器
  • SQLite的基本操作
  • 第五课 C#语言基本元素概览,初始类型,变量与方法,算法简介
  • 【系统分析师】第12章-关键技术:软件架构设计(核心总结)
  • Lightdash:一个免费开源的自助式BI平台
  • Claude Code 使用教程
  • UML(统一建模语言)
  • Android开发-常用布局
  • Spring Cloud Gateway 进行集群化部署
  • EmbodiedOneVision——类似π0.5集成了离散自回归解码与连续流匹配去噪:单个模型中完成具身推理、动作生成
  • Paper reading - 03. Speech sequencing in the human precentral gyrus
  • Spring事务失效的常见陷阱与解决方案
  • 现代C++:现代C++?
  • ZSet
  • Linux初级篇
  • MySQL集群高可用架构——组复制 (MGR)
  • MySQL Cluster核心优缺点