trl GRPO源码分析:如何处理多个reward function?
在GRPOTrainer源码中,多个reward function的处理主要通过以下步骤实现:
1. 初始化reward functions
-
reward_funcs
参数可以是单个reward function或一个列表。在初始化时,它被统一转换为列表:if not isinstance(reward_funcs, list):reward_funcs = [reward_funcs]
-
每个reward function可以是以下类型:
- 字符串:被视为预训练模型的ID,使用
AutoModelForSequenceClassification.from_pretrained
加载,设置num_labels=1
。 - PreTrainedModel:直接使用序列分类模型。
- 自定义函数:一个可调用对象,接受prompts和completions等参数,返回reward列表。自定义函数可以返回
None
表示不适用于某些样本。
- 字符串:被视为预训练模型的ID,使用
-
同时,记录每个reward function的名称用于日志:
for i, reward_func in enumerate(reward_funcs):if isinstance(reward_func, str):reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(...)if isinstance(reward_funcs[i], nn.Module):self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])else:self.reward_func_names.append(reward_funcs[i].__name__)
2. 设置reward weights
- 如果提供了
args.reward_weights
,则使用自定义权重,否则默认所有reward function的权重为1:if args.reward_weights is not None:if len(args.reward_weights) != len(reward_funcs):raise ValueError("Number of reward weights must match number of reward functions")self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) else:self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
- 权重用于后续的加权求和。
3. 处理reward processing classes
- 每个reward function可以有一个对应的processing class(如tokenizer),用于预处理输入。如果未提供,则自动加载:
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):if isinstance(reward_func, PreTrainedModel):if reward_processing_class is None:reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)# 设置pad tokenreward_processing_class.pad_token = reward_processing_class.eos_tokenreward_func.config.pad_token_id = reward_processing_class.pad_token_idreward_processing_classes[i] = reward_processing_class
- 对于自定义reward function,对应的processing class被忽略。
4. 计算rewards
- 在
_calculate_rewards
方法中,遍历所有reward function,计算每个function的reward:rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(...):if isinstance(reward_func, nn.Module):# 对于模型reward function:将prompts和completions组合,tokenize,然后通过模型计算rewardtexts = [p + c for p, c in zip(prompts, completions)] # 或应用chat templatereward_inputs = reward_processing_class(text=texts, padding=True, ...)with torch.inference_mode():rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]else:# 对于自定义reward function:直接调用函数,传入prompts、completions和其他参数output_reward_func = reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs)# 处理None值:转换为NaNoutput_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
- 这里,
reward_kwargs
包括数据集中除"prompt"、“completion”、"completion_ids"外的所有列,以及trainer_state
(训练状态)。
5. 加权求和rewards
- 所有reward function的输出被组合成一个张量
rewards_per_func
(形状为[batch_size, num_reward_funcs]
),然后与权重相乘并求和:rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
- 使用
nansum
是为了处理自定义reward function返回None
(转换为NaN)的情况。如果所有reward function对某个样本都返回NaN,则最终reward为NaN,并会发出警告。
6. 处理NaN情况
- 如果所有reward function对某个样本都返回NaN,会记录警告:
if torch.isnan(rewards_per_func).all(dim=1).any():nan_row_idx = ... # 找到第一个全NaN的行logger.warning("All reward functions returned None for the following kwargs: ...")
7. 日志记录
- 每个reward function的mean和std被记录到metrics中,便于监控:
for i, reward_func_name in enumerate(self.reward_func_names):mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)std_func_rewards = nanstd(rewards_per_func[:, i]).item()self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
- 这些日志在训练和评估时都会更新。
8. 优势计算和归一化
- 加权后的rewards用于计算优势(advantages)。rewards首先按prompt分组(每组包含
num_generations
个completion),计算组内mean,然后归一化:mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = rewards - mean_grouped_rewards
- 优势可以按组或批量缩放(通过
scale_rewards
参数控制)。
总结
多个reward function的处理是灵活的:支持混合类型(模型和自定义函数),允许自定义权重,并处理不适用情况(返回None)。最终reward是加权和,同时每个function的贡献被单独记录用于监控。这种设计适用于多任务学习,其中不同reward函数可能应用于不同类型样本。