GRPO (Group Relative Policy Optimization): Formula Cheat Sheet
A one-page set of formulas, distilled from ~1,600 lines of source code, ready to copy into a paper.
1 Group-Relative Advantage (within-group normalization)
Symbol | Meaning | Code variable |
---|---|---|
$q$ | prompt | `prompt` |
$G$ | group size | `num_generations` |
$o_i$ | the $i$-th completion | `completions[i]` |
$r(q, o_i)$ | reward | `rewards[i]` |
$$
\boxed{
\begin{aligned}
\mu_r(q) &= \frac{1}{G}\sum_{i=1}^{G} r(q, o_i) \\
\sigma_r(q) &= \operatorname{std}_{i=1..G}\, r(q, o_i) \\
A_{q,o_i} &= \frac{r(q, o_i) - \mu_r(q)}{\sigma_r(q) + \varepsilon} \qquad (\varepsilon = 1\times10^{-4})
\end{aligned}
}
$$
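A minimal sketch of this normalization, assuming a flat reward tensor of shape (B×G,) in which each block of `G` consecutive entries belongs to the same prompt (the function name and layout are illustrative, not TRL's exact code):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    # rewards: (B*G,) flat rewards, grouped as G consecutive completions per prompt
    grouped = rewards.view(-1, num_generations)      # (B, G)
    mean = grouped.mean(dim=1, keepdim=True)         # mu_r(q)
    std = grouped.std(dim=1, keepdim=True)           # sigma_r(q)
    advantages = (grouped - mean) / (std + eps)      # A_{q, o_i}
    return advantages.view(-1)                       # back to (B*G,)

# Example: two prompts, G = 3 completions each
rewards = torch.tensor([1.0, 0.0, 0.0, 2.0, 2.0, 0.0])
print(group_relative_advantages(rewards, num_generations=3))
```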
2 Clipped Policy Objective (token level)
$$
\boxed{
\begin{aligned}
r_{i,t}(\theta) &= \frac{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q, o_{i,<t})} \\[6pt]
L_{\text{clip}}(\theta) &= \sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \min\!\Bigl(r_{i,t}(\theta)\,A_{q,o_i},\ \operatorname{clip}\bigl(r_{i,t}(\theta),\,1-\varepsilon_{\text{low}},\,1+\varepsilon_{\text{high}}\bigr)\,A_{q,o_i}\Bigr)
\end{aligned}
}
$$
Code counterparts: `per_token_logps`, `old_per_token_logps`, `coef_1`, `coef_2`.
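A sketch of the objective above, mapped onto those variable names. The inputs are placeholders: per-token log-probs of shape (B×G, T), a per-sequence advantage of shape (B×G,), and a 0/1 completion mask.

```python
import torch

def clipped_objective(per_token_logps, old_per_token_logps, advantages,
                      completion_mask, eps_low=0.2, eps_high=0.2):
    # importance ratio r_{i,t}(theta), computed in log space for stability
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    # broadcast the per-sequence advantage A_{q,o_i} to every token of that sequence
    adv = advantages.unsqueeze(1)
    per_token_obj = torch.min(coef_1 * adv, coef_2 * adv)
    # only count valid completion tokens
    return (per_token_obj * completion_mask).sum()
```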
3 KL Regularization Term (optional, enabled when β > 0)
$$
\boxed{\mathrm{KL}_{\mathrm{reg}} = \beta \; D_{\mathrm{KL}}\!\bigl[\pi_{\theta}\,\|\,\pi_{\mathrm{ref}}\bigr]}
$$
Code counterparts: `per_token_kl`, `beta`.
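In TRL the per-token KL is not computed in closed form; it is approximated with the low-variance "k3" estimator built from the current and reference log-probs. A minimal sketch of that term:

```python
import torch

def per_token_kl_k3(per_token_logps: torch.Tensor, ref_per_token_logps: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(pi_theta || pi_ref): exp(ref - cur) - (ref - cur) - 1, always >= 0
    delta = ref_per_token_logps - per_token_logps
    return torch.exp(delta) - delta - 1.0

# The regularizer added to the loss is beta * per_token_kl, masked and
# aggregated the same way as the clipped term.
```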
4 Final Loss (token level)

Variant | Final loss | Normalizing denominator |
---|---|---|
GRPO | $L_{\text{total}} = -\dfrac{1}{G}\displaystyle\sum_{i=1}^{G}\dfrac{L_{\text{clip}}^{(i)}}{\lvert o_i\rvert} + \mathrm{KL}_{\mathrm{reg}}$, where $L_{\text{clip}}^{(i)}$ is the inner token sum for sequence $i$ | each sequence normalized by its own length, then averaged over sequences |
BNPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{\sum_{i}\lvert o_i\rvert} + \mathrm{KL}_{\mathrm{reg}}$ | total number of valid completion tokens in the batch |
DR-GRPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{B\cdot T} + \mathrm{KL}_{\mathrm{reg}}$ | fixed: batch size × max completion length |

- The variant is selected with the `loss_type` argument (a sketch of the three denominators follows).
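A simplified sketch of the three normalizations, consistent with the branching in TRL's `compute_loss` (the helper name `normalize_loss` is ours; `per_token_loss` is the negative clipped term plus the KL term, `completion_mask` is 0/1):

```python
import torch

def normalize_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
    if loss_type == "grpo":
        # each sequence divided by its own valid length, then mean over sequences
        return ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
    elif loss_type == "bnpo":
        # summed token loss divided by the total number of valid tokens in the batch
        return (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
    elif loss_type == "dr_grpo":
        # fixed denominator: batch size x max completion length
        return (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * max_completion_length)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
```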
**BNPO vs. GRPO: the one-line takeaway**
BNPO = GRPO plus a "reward-normalization plug-in".
Both share the same "group-relative advantage + KL + clip" framework; BNPO only swaps the static mean/std normalization for an adaptive Beta normalization.
✅ Key differences
Dimension | GRPO | BNPO |
---|---|---|
Normalization | group mean/std (static) | adaptive Beta distribution |
Reward assumption | any real value | binary (Bernoulli) rewards |
Baseline update | recompute μ, σ every batch | α, β parameters updated online |
Gradient variance | fixed | shrinks as the policy improves |
Superset of GRPO? | ✗ | yes: GRPO is the special case with fixed Beta parameters |
✅ Formula comparison (at a glance)
Algorithm | Advantage function |
---|---|
GRPO | $A_{\text{GRPO}} = \dfrac{r - \mu_r}{\sigma_r + \varepsilon}$ |
BNPO | $A_{\text{BNPO}} = \dfrac{r - \hat{\mu}}{\sqrt{\hat{\mu}(1-\hat{\mu}) + \delta}}$, where $\hat{\mu}\sim\mathrm{Beta}(\alpha,\beta)$ and α, β are estimated online from the rewards of the most recent N steps (toy sketch below). |
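A toy sketch of the two normalizations in the table, for binary rewards. The online Beta update here (simple success/failure counting over a sliding window of recent rewards) is an illustrative assumption, not the exact BNPO estimator:

```python
import torch

def grpo_advantage(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # static group statistics
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def bnpo_advantage(rewards: torch.Tensor, recent_rewards: torch.Tensor, delta: float = 1e-4) -> torch.Tensor:
    # binary rewards assumed; alpha/beta counted from a window of recent rewards (illustrative only)
    alpha = 1.0 + recent_rewards.sum()
    beta = 1.0 + (1.0 - recent_rewards).sum()
    mu_hat = alpha / (alpha + beta)                  # Beta-posterior mean success rate
    return (rewards - mu_hat) / torch.sqrt(mu_hat * (1 - mu_hat) + delta)

group = torch.tensor([1.0, 0.0, 0.0, 1.0])            # one group of binary rewards
window = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0, 1.0])  # rewards from the last N steps
print(grpo_advantage(group))
print(bnpo_advantage(group, window))
```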
✅ When to use which
- GRPO: general-purpose, simple, quick to implement.
- BNPO:
  - rewards are binary 0/1 (correct / incorrect)
  - the reward distribution drifts heavily early in training (e.g., math-reasoning tasks)
  - you need lower gradient variance and more stability
✅ Takeaway
- BNPO does not replace GRPO; it upgrades the static mean baseline to a dynamic Beta baseline.
- In TRL you only need to set `loss_type='bnpo'` to enable it; the rest of the pipeline (sampling, clipping, KL) stays exactly the same, as sketched below.
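A minimal usage sketch of switching the loss variant via `GRPOConfig`. The model id, dataset, reward function, and hyperparameter values are placeholders for illustration, not recommended settings:

```python
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # toy reward assuming plain-text completions: prefer completions near 100 characters
    return [-abs(100 - len(c)) for c in completions]

config = GRPOConfig(
    output_dir="grpo-demo",
    loss_type="bnpo",    # other choices include "grpo" and "dr_grpo"
    num_generations=8,   # group size G
    beta=0.04,           # KL coefficient; 0 disables the KL term
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    reward_funcs=reward_len,
    args=config,
    train_dataset=my_dataset,            # placeholder dataset with a "prompt" column
)
trainer.train()
```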
How is this implemented in Hugging Face TRL?
Computing rewards
```python
def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
    device = self.accelerator.device
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)

    # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
    keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

    # This allows for dynamic reward shaping based on training progress.
    reward_kwargs["trainer_state"] = self.state

    for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
    ):
        with profiling_context(self, reward_func_name):
            if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                )
                # Convert None values to NaN
                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # If all reward functions return None for a given row, issue a detailed warning
    if torch.isnan(rewards_per_func).all(dim=1).any():
        nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
        row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
        row_reward_kwargs["prompt"] = prompts[nan_row_idx]
        row_reward_kwargs["completion"] = completions[nan_row_idx]
        warnings.warn(
            f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
            "Please ensure that at least one reward function returns a valid reward."
        )

    # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
    # completions may be distributed across processes
    rewards_per_func = gather(rewards_per_func)
    return rewards_per_func
```
`_calculate_rewards` feeds the batch of completions to each of the reward functions, returns a `(B×G, F)` reward matrix, and synchronizes it across processes in preparation for GRPO's group-wise normalization.
✅ Inputs and outputs
Argument | Shape / meaning |
---|---|
`prompts` | `List[str]` or `List[Messages]`, length = B×G |
`completions` | `List[str]` or `List[Messages]`, length = B×G |
`completion_ids_list` | `List[List[int]]` of token ids, length = B×G |
return value | `Tensor` of shape `(B×G, F)`, F = number of reward functions |
✅ Core steps
1. Initialize the container: `rewards_per_func = zeros(B×G, F)` reserves the output slots.
2. Pack the extra columns into kwargs: every field of `inputs[0]` other than `"prompt"` / `"completion"` / `"completion_ids"` is collected row by row and passed on to custom reward functions.
3. Loop over the F reward functions:
   - If it is a model (`nn.Module`): build the `prompt + completion` text → tokenizer → forward pass → take `logits[:, 0]` as the scalar reward.
   - If it is a callable: call it directly; it may return `None`, which is converted to `NaN` as a placeholder.
4. Cross-process sync: `gather(rewards_per_func)` gives every GPU the global `(N×G, F)` reward matrix, so the subsequent group-wise normalization is consistent.
5. Sanity check: if an entire row is `NaN`, a detailed warning is issued, making it easy to spot reward functions that never return a value.
✅ Summary
"Feed the B×G completions to the F reward functions, gather the results across processes, and produce a `(B×G, F)` reward tensor for GRPO's group-wise normalization."
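A minimal sketch of a custom reward function compatible with the call site above: `_calculate_rewards` invokes it as `reward_func(prompts=..., completions=..., completion_ids=..., **reward_kwargs)` and expects one float (or `None`) per completion. The exact-match logic and the `answer` dataset column are assumptions for illustration:

```python
def exact_match_reward(prompts, completions, completion_ids, answer=None, **kwargs):
    # assumes plain-text completions and an "answer" column forwarded via reward_kwargs
    rewards = []
    for completion, gold in zip(completions, answer):
        if gold is None:
            rewards.append(None)  # None becomes a NaN placeholder, ignored by nansum/nanmean
        else:
            rewards.append(1.0 if gold.strip() in completion else 0.0)
    return rewards
```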
_generate_and_score_completions
```python
def _generate_and_score_completions(
    self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    prompts = [x["prompt"] for x in inputs]

    # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
    # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
    # VLM chat template.
    original_prompts = copy.deepcopy(prompts)

    # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
    # [{"role": "user", "content": "What color is the sky?"}] to
    # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
    kwargs = {}
    has_images = "image" in inputs[0]
    if has_images:
        images = [example.get("image") for example in inputs]
        kwargs = {"images": [[img] for img in images]}
        for prompt in prompts:
            if isinstance(prompt, list):
                for message in prompt:
                    if not isinstance(message, dict):
                        continue
                    content = message.get("content")
                    role = message.get("role")
                    if isinstance(content, str):
                        if role == "user":
                            message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
                        elif role == "system":
                            message["content"] = [{"type": "text", "text": content}]

    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
    prompt_inputs = self.processing_class(
        text=prompts_text,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
        **kwargs,
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    if self.max_prompt_length is not None:
        # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
        # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
        # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
        protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
        protected = [token for token in protected if token is not None]
        prompt_ids, prompt_mask = truncate_with_protected_tokens(
            prompt_ids, prompt_mask, self.max_prompt_length, protected
        )

        prompts_text = self.processing_class.batch_decode(
            prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]

        # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
        # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
        # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
        # collapse them back into a single token string to match the original chat template in case it originally
        # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
        # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
        # the vision_start_token_id (e.g. <start_of_image>).
        if self.image_token is not None:
            escaped_img_token = re.escape(self.image_token)
            # Search for the image token in the chat template
            if re.search(escaped_img_token, self.processing_class.chat_template):
                prompts_text = [
                    re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
                ]
            else:
                # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
                if self.vision_end_token_id is not None:
                    escaped_eoi_token = re.escape(
                        self.processing_class.tokenizer.decode([self.vision_end_token_id])
                    )
                    prompts_text = [
                        re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
                    ]
                else:
                    # If vision_end_token_id is None, just remove the image tokens
                    prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]

    # Generate completions using either vLLM or regular generation
    if self.use_vllm:
        # First, update the vLLM weights if needed
        if self.state.global_step != self._last_loaded_step:
            self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
        if self.vllm_mode == "server":
            all_prompts_text = gather_object(prompts_text)
            if has_images:
                all_images = gather_object(images)

            if self.accelerator.is_main_process:
                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                if has_images:
                    ordered_set_of_images = all_images[:: self.num_generations]
                else:
                    ordered_set_of_images = None
                with profiling_context(self, "vLLM.generate"):
                    completion_ids = self.vllm_client.generate(
                        prompts=ordered_set_of_prompts,
                        images=ordered_set_of_images,
                        n=self.num_generations,
                        repetition_penalty=self.repetition_penalty,
                        temperature=self.temperature,
                        top_p=self.top_p,
                        top_k=-1 if self.top_k is None else self.top_k,
                        min_p=0.0 if self.min_p is None else self.min_p,
                        max_tokens=self.max_completion_length,
                        guided_decoding_regex=self.guided_decoding_regex,
                        generation_kwargs=self.args.generation_kwargs,
                    )
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

        # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
        elif self.vllm_mode == "colocate":
            if self.guided_decoding_regex:
                guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
            else:
                guided_decoding = None
            generation_kwargs = {
                "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
                "repetition_penalty": self.repetition_penalty,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": -1 if self.top_k is None else self.top_k,
                "min_p": 0.0 if self.min_p is None else self.min_p,
                "max_tokens": self.max_completion_length,
                "guided_decoding": guided_decoding,
            }
            if self.args.generation_kwargs is not None:
                generation_kwargs.update(self.args.generation_kwargs)
            sampling_params = SamplingParams(**generation_kwargs)

            if self.vllm_tensor_parallel_size > 1:
                # Gather prompts from all ranks in the TP group and flatten.
                # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
                orig_size = len(prompts_text)
                gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
                torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

                if has_images:
                    gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
                    torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
                    all_images = [img for sublist in gathered_images for img in sublist]
                else:
                    all_images = None
            else:
                all_prompts_text = prompts_text
                all_images = images if has_images else None

            if has_images and all_images:
                vllm_inputs = []
                for prompt, image in zip(all_prompts_text, all_images):
                    if image is not None:
                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
                    else:
                        vllm_inputs.append(prompt)
            else:
                vllm_inputs = all_prompts_text

            with profiling_context(self, "vLLM.generate"):
                all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

            completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

            if self.vllm_tensor_parallel_size > 1:
                # Slice completions for this rank within its TP group.
                # Each rank generates all outputs — we keep only our share.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                completion_ids = completion_ids[tp_slice]

        # Pad the completions, and concatenate them with the prompts
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

    elif self.use_transformers_paged:
        # Re-process inputs for paged generation if needed
        # Note: images are already validated and preprocessed above
        paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
        previous_attn = self.model_wrapped.config._attn_implementation

        if is_flash_attn_2_available():
            self.model_wrapped.config._attn_implementation = "paged_attention"
        else:
            self.model_wrapped.config._attn_implementation = "sdpa_paged"
        with (
            profiling_context(self, "transformers.generate_batch"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            # Cast to the appropriate dtype based on training configuration
            if self.args.bf16:
                unwrapped_model.to(torch.bfloat16)
            elif self.args.fp16:
                unwrapped_model.to(torch.float16)
            with torch.inference_mode():
                all_outputs = unwrapped_model.generate_batch(
                    paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
                )
        completion_ids = [output.generated_tokens for output in all_outputs.values()]
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
        prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        # Restore the original attention implementation, training mode
        self.model_wrapped.config._attn_implementation = previous_attn
    else:
        # Regular generation path
        with (
            profiling_context(self, "transformers.generate"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
            prompt_completion_ids = unwrapped_model.generate(
                **prompt_inputs, generation_config=self.generation_config, disable_compile=True
            )
        # Compute prompt length and extract completion ids
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
    # to re-tokenize completions if the reward is computed from tokens.
    completion_ids_list = [
        [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
    ]

    # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
    completion_lengths = completion_mask.sum(1)

    # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
    if self.mask_truncated_completions:
        truncated_completions = ~is_eos.any(dim=1)
        completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

    with torch.no_grad():
        # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
        # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
        # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
        # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
        # old_per_token_logps to None.
        generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
        if self.args.gradient_accumulation_steps % generate_every != 0:
            old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                self.model,
                prompt_completion_ids,
                attention_mask,
                logits_to_keep,
                batch_size,
                pixel_values=prompt_inputs.get("pixel_values"),
                image_grid_thw=prompt_inputs.get("image_grid_thw"),
                pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                image_sizes=prompt_inputs.get("image_sizes"),
            )
        else:
            old_per_token_logps = None

        # Compute the per-token log probabilities for the reference model
        if self.beta != 0.0:
            if self.ref_model is not None:
                ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                    self.ref_model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size=batch_size,
                    pixel_values=prompt_inputs.get("pixel_values"),
                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                    image_sizes=prompt_inputs.get("image_sizes"),
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        batch_size=batch_size,
                        pixel_values=prompt_inputs.get("pixel_values"),
                        image_grid_thw=prompt_inputs.get("image_grid_thw"),
                        pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                        image_sizes=prompt_inputs.get("image_sizes"),
                    )
        else:
            ref_per_token_logps = None

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
        completions = []
        for prompt, completion in zip(prompts, completions_text):
            bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
            completions.append([{"role": "assistant", "content": bootstrap + completion}])
    else:
        completions = completions_text

    # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
    # important because rewards will be normalized per group, and completions are distributed. We will later slice
    # rewards_per_func to extract each process's subset.
    rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)

    # Apply weights to each reward function's output and sum
    rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    advantages = rewards - mean_grouped_rewards
    if self.scale_rewards:
        advantages = advantages / (std_grouped_rewards + 1e-4)

    # Slice to keep only the local part of the data
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
    advantages = advantages[process_slice]

    # Log the metrics
    if mode == "train":
        self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
    self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

    # Log completion lengths, mean, min, max
    agg_completion_lengths = self.accelerator.gather(completion_lengths)
    self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

    # Identify sequences that terminated with EOS and log their lengths
    agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
    term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
    clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
    self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
    if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
        term_completion_lengths = torch.zeros(1, device=device)
    self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

    # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
    for i, reward_func_name in enumerate(self.reward_func_names):
        mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
        std_rewards = nanstd(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
    self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
    self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
    self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

    # Log prompt and completion texts
    self._logs["prompt"].extend(gather_object(prompts_text))
    self._logs["completion"].extend(gather_object(completions_text))
    for i, name in enumerate(self.reward_func_names):
        self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
    self._logs["advantages"].extend(all_process_advantages.tolist())

    if has_images:
        self._logs["image"].extend(gather_object(images))

    output = {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "advantages": advantages,
    }
    if old_per_token_logps is not None:
        output["old_per_token_logps"] = old_per_token_logps
    if ref_per_token_logps is not None:
        output["ref_per_token_logps"] = ref_per_token_logps
    if "pixel_values" in prompt_inputs:
        output["pixel_values"] = prompt_inputs["pixel_values"]
    if "image_grid_thw" in prompt_inputs:
        output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
    if "pixel_attention_mask" in prompt_inputs:
        output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
    if "image_sizes" in prompt_inputs:
        output["image_sizes"] = prompt_inputs["image_sizes"]
    return output
```
`_generate_and_score_completions` is the heart of `GRPOTrainer`:
in a single pass it handles prompt preprocessing → multi-backend generation → reward scoring → group-wise normalization, and emits every tensor the training loss needs.
✅ One-sentence summary
"Turn a batch of prompts into B×G completions, score them, compute group-relative advantages, and pack everything into a training dict that can be fed straight to the loss."
✅ Core flow (8 steps)
Step | Key action | Code / variables |
---|---|---|
1️⃣ Input prep | extract prompts, handle image+text inputs | `prompts`, `has_images` |
2️⃣ Tokenization | left padding, truncation, protect special tokens | `truncate_with_protected_tokens` |
3️⃣ Generation | vLLM / transformers / paged attention | `completion_ids` |
4️⃣ Post-processing | truncate at EOS, build masks (see the sketch after this table) | `completion_mask`, `completion_lengths` |
5️⃣ Reward scoring | call `_calculate_rewards` | `rewards_per_func` |
6️⃣ Weighted sum | combine multiple reward functions into one reward | `rewards` |
7️⃣ Group normalization | mean/std → advantages | `advantages` |
8️⃣ Cross-process sync | gather & slice for distributed consistency | `gather`, `process_slice` |
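A standalone sketch of step 4 (masking everything after the first EOS token), mirroring the logic in the function above; the helper name and the toy token ids are ours:

```python
import torch

def build_completion_mask(completion_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # completion_ids: (B*G, C) generated token ids, right-padded
    is_eos = completion_ids == eos_token_id
    # default: no EOS found -> keep the full length
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
    # keep tokens up to and including the first EOS; zero out everything after it
    return (sequence_indices <= eos_idx.unsqueeze(1)).int()

mask = build_completion_mask(torch.tensor([[5, 7, 2, 0, 0], [5, 7, 9, 9, 9]]), eos_token_id=2)
# -> [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```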
✅ Output dict (ready for the loss)
```python
{
    "prompt_ids":          Tensor,  # (B×G, P)
    "prompt_mask":         Tensor,  # (B×G, P)
    "completion_ids":      Tensor,  # (B×G, C)
    "completion_mask":     Tensor,  # (B×G, C)
    "advantages":          Tensor,  # (B×G,) group-normalized advantages
    "old_per_token_logps": Tensor,  # optional, importance sampling
    "ref_per_token_logps": Tensor,  # optional, KL computation
    ...                             # image-related fields (if multimodal)
}
```
✅ Wrap-up
A single call to `_generate_and_score_completions` turns raw prompts into training samples with advantages attached.