
GRPO (Group Relative Policy Optimization): Formula Quick Reference

Roughly 1,600 lines of source code distilled into a one-page formula sheet you can copy straight into a paper.


1 Group-Relative Advantage

| Symbol | Meaning | Code variable |
| --- | --- | --- |
| $q$ | prompt | prompt |
| $G$ | group size | num_generations |
| $o_i$ | the $i$-th completion | completions[i] |
| $r(q, o_i)$ | reward | rewards[i] |

$$
\boxed{
\begin{aligned}
\mu_r(q) &= \frac{1}{G}\sum_{i=1}^{G} r(q, o_i) \\
\sigma_r(q) &= \operatorname{std}_{i=1..G}\, r(q, o_i) \\
A_{q,o_i} &= \frac{r(q, o_i) - \mu_r(q)}{\sigma_r(q) + \varepsilon} \qquad (\varepsilon = 1\times10^{-4})
\end{aligned}
}
$$
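In code this is just a reshape into groups of num_generations followed by per-group standardization. Below is a minimal PyTorch sketch of exactly this step; the function name and the example numbers are illustrative, not TRL code.

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    """rewards: flat (B*G,) tensor, where each consecutive block of
    `num_generations` entries belongs to the same prompt q."""
    grouped = rewards.view(-1, num_generations)        # (B, G)
    mu = grouped.mean(dim=1, keepdim=True)             # mu_r(q)
    sigma = grouped.std(dim=1, keepdim=True)           # sigma_r(q)
    advantages = (grouped - mu) / (sigma + eps)        # A_{q, o_i}
    return advantages.view(-1)                         # back to (B*G,)

# Example: 2 prompts, G = 4 completions each
rewards = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0])
print(group_relative_advantages(rewards, num_generations=4))
```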


2 Clipped Policy Objective (Token-Level)

$$
\boxed{
\begin{aligned}
r_{i,t}(\theta) &= \frac{\pi_\theta(o_{i,t}\mid q,\, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q,\, o_{i,<t})} \\[6pt]
L_{\text{clip}}(\theta) &= \sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \min\!\Bigl(r_{i,t}(\theta)\,A_{q,o_i},\;\operatorname{clip}\bigl(r_{i,t}(\theta),\,1-\varepsilon_{\text{low}},\,1+\varepsilon_{\text{high}}\bigr)\,A_{q,o_i}\Bigr)
\end{aligned}
}
$$

Code correspondence: per_token_logps, old_per_token_logps, coef_1, coef_2
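To map the formula onto those variables, here is a minimal PyTorch sketch of the per-token clipped term; coef_1 and coef_2 mirror the names above, while the function itself and the default ε values are illustrative.

```python
import torch

def clipped_per_token_objective(per_token_logps: torch.Tensor,
                                old_per_token_logps: torch.Tensor,
                                advantages: torch.Tensor,
                                eps_low: float = 0.2, eps_high: float = 0.2) -> torch.Tensor:
    """per_token_logps, old_per_token_logps: (B*G, T); advantages: (B*G,)."""
    # r_{i,t}(theta) = pi_theta / pi_theta_old, computed in log space for stability
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    coef_1 = ratio                                          # unclipped ratio
    coef_2 = torch.clamp(ratio, 1 - eps_low, 1 + eps_high)  # clipped ratio
    adv = advantages.unsqueeze(1)                           # broadcast sequence-level advantage over tokens
    # min(r * A, clip(r) * A); the trainer negates this (and masks pad tokens) to form the loss
    return torch.min(coef_1 * adv, coef_2 * adv)
```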


3 KL Regularization (optional, enabled when β > 0)

$$
\boxed{\,\mathrm{KL}_{\mathrm{reg}} = \beta \, D_{\mathrm{KL}}\!\bigl[\pi_{\theta}\,\|\,\pi_{\mathrm{ref}}\bigr]\,}
$$

Code correspondence: per_token_kl, beta
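TRL estimates this KL per token with the low-variance, always non-negative k3 approximation rather than the exact expectation; a minimal sketch with an illustrative function name:

```python
import torch

def k3_per_token_kl(per_token_logps: torch.Tensor, ref_per_token_logps: torch.Tensor) -> torch.Tensor:
    """Per-token estimate of D_KL[pi_theta || pi_ref] using the k3 form:
    exp(log pi_ref - log pi_theta) - (log pi_ref - log pi_theta) - 1, which is always >= 0."""
    diff = ref_per_token_logps - per_token_logps
    return torch.exp(diff) - diff - 1
```

The result is scaled by beta and added to the per-token loss before masking and normalization.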


4 Final Loss (Token-Level)

| Loss type | Final loss |
| --- | --- |
| GRPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{\sum_t 1} + \mathrm{KL}_{\mathrm{reg}}$ |
| BNPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{\sum_t 1} + \mathrm{KL}_{\mathrm{reg}}$ |
| DR-GRPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{B\cdot T} + \mathrm{KL}_{\mathrm{reg}}$ |

  • Selected in code via the loss_type parameter; the options differ only in how the summed per-token loss is normalized.
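A hedged sketch of how recent TRL versions normalize the masked per-token loss for each option: grpo averages per sequence and then across sequences, bnpo averages over all unmasked tokens in the batch, and dr_grpo divides by the constant B·T with T = max_completion_length. Variable names follow the source quoted later; exact details may shift between releases.

```python
import torch

def normalize_loss(per_token_loss: torch.Tensor, completion_mask: torch.Tensor,
                   loss_type: str, max_completion_length: int) -> torch.Tensor:
    """per_token_loss, completion_mask: (B*G, T). Returns a scalar training loss."""
    masked = per_token_loss * completion_mask
    if loss_type == "grpo":
        # per-sequence token average, then average across sequences
        return (masked.sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
    if loss_type == "bnpo":
        # single average over every unmasked token in the batch
        return masked.sum() / completion_mask.sum().clamp(min=1.0)
    if loss_type == "dr_grpo":
        # divide by the constant B * T, removing the length bias (Dr. GRPO)
        return masked.sum() / (per_token_loss.size(0) * max_completion_length)
    raise ValueError(f"Unknown loss_type: {loss_type}")
```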

**BNPO vs. GRPO: one-line summary**

BNPO = GRPO with a swapped-in reward-normalization module.
The two share the same "group-relative advantage + clip + KL" framework; BNPO simply replaces the static mean-variance normalization with a dynamic Beta normalization.


✅ Key differences

| Dimension | GRPO | BNPO |
| --- | --- | --- |
| Normalization | group mean-variance (static) | adaptive Beta distribution |
| Reward assumption | any numeric value | binary (Bernoulli) rewards |
| Baseline update | recompute μ, σ every batch | update α, β parameters online |
| Gradient variance | fixed | shrinks as the policy improves |
| Superset of GRPO? | — | yes: GRPO is the special case with fixed Beta parameters |

✅ Formula comparison (at a glance)

| Algorithm | Advantage function |
| --- | --- |
| GRPO | $A_{\text{GRPO}} = \dfrac{r - \mu_r}{\sigma_r + \varepsilon}$ |
| BNPO | $A_{\text{BNPO}} = \dfrac{r - \hat{\mu}}{\sqrt{\hat{\mu}(1-\hat{\mu}) + \delta}}$ |

Here $\hat{\mu}\sim\text{Beta}(\alpha,\beta)$, and the parameters α, β are estimated online from the rewards of the most recent N steps.
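To make the Beta baseline concrete, here is a purely illustrative sketch (not TRL or BNPO reference code): α and β accumulate recent binary rewards, and the implied Bernoulli mean and variance replace μ_r and σ_r in the advantage.

```python
import math
import torch

class BetaRewardNormalizer:
    """Illustrative only: keep a running Beta(alpha, beta) estimate of the success
    rate of binary rewards and normalize with the implied Bernoulli mean/std."""

    def __init__(self, alpha: float = 1.0, beta: float = 1.0, delta: float = 1e-4):
        self.alpha, self.beta, self.delta = alpha, beta, delta

    def update(self, rewards: torch.Tensor) -> None:
        # rewards are assumed to be 0/1
        self.alpha += rewards.sum().item()
        self.beta += (1.0 - rewards).sum().item()

    def advantages(self, rewards: torch.Tensor) -> torch.Tensor:
        mu_hat = self.alpha / (self.alpha + self.beta)      # posterior mean of the success rate
        return (rewards - mu_hat) / math.sqrt(mu_hat * (1.0 - mu_hat) + self.delta)
```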

✅ When to use which

  • GRPO: general-purpose, simple, quick to implement.
  • BNPO, when:
    • rewards are binary 0/1 (correct / incorrect)
    • the reward distribution drifts heavily early in training (e.g. math-reasoning tasks)
    • you need lower gradient variance and more stability

✅ Takeaways

  • BNPO does not overturn GRPO; it only upgrades the static mean baseline to a dynamic Beta baseline.
  • In TRL it is enabled by setting loss_type='bnpo'; the rest of the pipeline (sampling, clip, KL) is unchanged.
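Assuming a recent TRL release, the switch really is a one-line config change. In the sketch below the model name, dataset, and reward function are placeholders for illustration only.

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

def contains_42_reward(prompts, completions, **kwargs):
    # placeholder reward: 1.0 if the completion mentions "42", else 0.0
    return [1.0 if "42" in c else 0.0 for c in completions]

train_dataset = Dataset.from_dict({"prompt": ["What is 6 x 7?"] * 16})

args = GRPOConfig(
    output_dir="grpo-demo",
    num_generations=8,   # G: group size
    loss_type="bnpo",    # "grpo" | "bnpo" | "dr_grpo"
    beta=0.0,            # set > 0 to enable the KL term against the reference model
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # any causal LM id works here
    args=args,
    reward_funcs=contains_42_reward,
    train_dataset=train_dataset,
)
# trainer.train()
```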

How is it implemented in Hugging Face TRL?

Computing the reward

def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
    device = self.accelerator.device
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)

    # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
    keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

    # This allows for dynamic reward shaping based on training progress.
    reward_kwargs["trainer_state"] = self.state

    for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
    ):
        with profiling_context(self, reward_func_name):
            if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                )
                # Convert None values to NaN
                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # If all reward functions return None for a given row, issue a detailed warning
    if torch.isnan(rewards_per_func).all(dim=1).any():
        nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
        row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
        row_reward_kwargs["prompt"] = prompts[nan_row_idx]
        row_reward_kwargs["completion"] = completions[nan_row_idx]
        warnings.warn(
            f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
            "Please ensure that at least one reward function returns a valid reward."
        )

    # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
    # completions may be distributed across processes
    rewards_per_func = gather(rewards_per_func)
    return rewards_per_func

_calculate_rewards feeds every completion to every reward function, returns a (B×G, F) reward matrix, and synchronizes it across processes in preparation for GRPO's group-wise normalization.


✅ Inputs and outputs

| Parameter | Shape / meaning |
| --- | --- |
| prompts | List[str] or List[Messages], length = B×G |
| completions | List[str] or List[Messages], length = B×G |
| completion_ids_list | List[List[int]] of token ids, length = B×G |
| Return value | Tensor of shape (B×G, F), F = number of reward functions |

✅ Core steps

  1. Initialize the container
    rewards_per_func = zeros(B×G, F) reserves the output slots.

  2. Pack the extra columns into kwargs
    Every field of inputs[0] other than "prompt"/"completion"/"completion_ids" is repeated row-wise and passed through to custom reward functions.

  3. Loop over the F reward functions

    • If it is a model (nn.Module):
      • Build the prompt+completion text → tokenizer → forward pass → take logits[:, 0] as the scalar reward.
    • If it is a plain function (Callable):
      • Call it directly; it may return None, which is stored as NaN.
  4. Cross-process synchronization
    gather(rewards_per_func): every GPU gets the global (N×G, F) reward matrix (N = total prompts across processes), so the subsequent group-wise normalization is consistent.

  5. Error checking
    If an entire row is NaN, a detailed warning is emitted to help track down reward functions that never return a value.


✅ Summary

"Feed B×G completions to F reward functions, collect the results across processes, and produce a (B×G, F) reward tensor for GRPO's group-wise normalization."
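For the Callable branch in step 3, a custom reward function only has to accept prompts, completions, completion_ids (plus any extra dataset columns via **kwargs) and return one float, or None, per row. A minimal illustrative sketch:

```python
def token_length_reward(prompts, completions, completion_ids, **kwargs):
    """Illustrative reward: 1.0 for completions shorter than 200 tokens, otherwise None.
    None is stored as NaN in rewards_per_func and ignored by the weighted nansum."""
    return [1.0 if len(ids) < 200 else None for ids in completion_ids]
```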

_generate_and_score_completions

def _generate_and_score_completions(self, inputs: list[dict[str, Union[torch.Tensor, Any]]]) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    prompts = [x["prompt"] for x in inputs]

    # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
    # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
    # VLM chat template.
    original_prompts = copy.deepcopy(prompts)

    # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
    # [{"role": "user", "content": "What color is the sky?"}] to
    # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
    kwargs = {}
    has_images = "image" in inputs[0]
    if has_images:
        images = [example.get("image") for example in inputs]
        kwargs = {"images": [[img] for img in images]}
        for prompt in prompts:
            if isinstance(prompt, list):
                for message in prompt:
                    if not isinstance(message, dict):
                        continue
                    content = message.get("content")
                    role = message.get("role")
                    if isinstance(content, str):
                        if role == "user":
                            message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
                        elif role == "system":
                            message["content"] = [{"type": "text", "text": content}]

    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
    prompt_inputs = self.processing_class(
        text=prompts_text,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
        **kwargs,
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    if self.max_prompt_length is not None:
        # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
        # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
        # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
        protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
        protected = [token for token in protected if token is not None]
        prompt_ids, prompt_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, self.max_prompt_length, protected)

        prompts_text = self.processing_class.batch_decode(
            prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]

        # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
        # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
        # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
        # collapse them back into a single token string to match the original chat template in case it originally
        # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
        # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
        # the vision_start_token_id (e.g. <start_of_image>).
        if self.image_token is not None:
            escaped_img_token = re.escape(self.image_token)
            # Search for the image token in the chat template
            if re.search(escaped_img_token, self.processing_class.chat_template):
                prompts_text = [re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text]
            else:
                # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
                if self.vision_end_token_id is not None:
                    escaped_eoi_token = re.escape(self.processing_class.tokenizer.decode([self.vision_end_token_id]))
                    prompts_text = [
                        re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
                    ]
                else:
                    # If vision_end_token_id is None, just remove the image tokens
                    prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]

    # Generate completions using either vLLM or regular generation
    if self.use_vllm:
        # First, update the vLLM weights if needed
        if self.state.global_step != self._last_loaded_step:
            self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
        if self.vllm_mode == "server":
            all_prompts_text = gather_object(prompts_text)
            if has_images:
                all_images = gather_object(images)

            if self.accelerator.is_main_process:
                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                if has_images:
                    ordered_set_of_images = all_images[:: self.num_generations]
                else:
                    ordered_set_of_images = None
                with profiling_context(self, "vLLM.generate"):
                    completion_ids = self.vllm_client.generate(
                        prompts=ordered_set_of_prompts,
                        images=ordered_set_of_images,
                        n=self.num_generations,
                        repetition_penalty=self.repetition_penalty,
                        temperature=self.temperature,
                        top_p=self.top_p,
                        top_k=-1 if self.top_k is None else self.top_k,
                        min_p=0.0 if self.min_p is None else self.min_p,
                        max_tokens=self.max_completion_length,
                        guided_decoding_regex=self.guided_decoding_regex,
                        generation_kwargs=self.args.generation_kwargs,
                    )
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

        # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
        elif self.vllm_mode == "colocate":
            if self.guided_decoding_regex:
                guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
            else:
                guided_decoding = None

            generation_kwargs = {
                "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
                "repetition_penalty": self.repetition_penalty,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": -1 if self.top_k is None else self.top_k,
                "min_p": 0.0 if self.min_p is None else self.min_p,
                "max_tokens": self.max_completion_length,
                "guided_decoding": guided_decoding,
            }
            if self.args.generation_kwargs is not None:
                generation_kwargs.update(self.args.generation_kwargs)
            sampling_params = SamplingParams(**generation_kwargs)

            if self.vllm_tensor_parallel_size > 1:
                # Gather prompts from all ranks in the TP group and flatten.
                # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
                orig_size = len(prompts_text)
                gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
                torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

                if has_images:
                    gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
                    torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
                    all_images = [img for sublist in gathered_images for img in sublist]
                else:
                    all_images = None
            else:
                all_prompts_text = prompts_text
                all_images = images if has_images else None

            if has_images and all_images:
                vllm_inputs = []
                for prompt, image in zip(all_prompts_text, all_images):
                    if image is not None:
                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
                    else:
                        vllm_inputs.append(prompt)
            else:
                vllm_inputs = all_prompts_text

            with profiling_context(self, "vLLM.generate"):
                all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

            completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

            if self.vllm_tensor_parallel_size > 1:
                # Slice completions for this rank within its TP group.
                # Each rank generates all outputs — we keep only our share.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                completion_ids = completion_ids[tp_slice]

        # Pad the completions, and concatenate them with the prompts
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

    elif self.use_transformers_paged:
        # Re-process inputs for paged generation if needed
        # Note: images are already validated and preprocessed above
        paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
        previous_attn = self.model_wrapped.config._attn_implementation

        if is_flash_attn_2_available():
            self.model_wrapped.config._attn_implementation = "paged_attention"
        else:
            self.model_wrapped.config._attn_implementation = "sdpa_paged"
        with (
            profiling_context(self, "transformers.generate_batch"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            # Cast to the appropriate dtype based on training configuration
            if self.args.bf16:
                unwrapped_model.to(torch.bfloat16)
            elif self.args.fp16:
                unwrapped_model.to(torch.float16)
            with torch.inference_mode():
                all_outputs = unwrapped_model.generate_batch(
                    paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
                )
        completion_ids = [output.generated_tokens for output in all_outputs.values()]
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
        prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        # Restore the original attention implementation, training mode
        self.model_wrapped.config._attn_implementation = previous_attn
    else:
        # Regular generation path
        with (
            profiling_context(self, "transformers.generate"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
            prompt_completion_ids = unwrapped_model.generate(
                **prompt_inputs, generation_config=self.generation_config, disable_compile=True
            )
        # Compute prompt length and extract completion ids
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
    # to re-tokenize completions if the reward is computed from tokens.
    completion_ids_list = [
        [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
    ]

    # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
    completion_lengths = completion_mask.sum(1)

    # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
    if self.mask_truncated_completions:
        truncated_completions = ~is_eos.any(dim=1)
        completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

    with torch.no_grad():
        # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
        # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
        # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
        # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
        # old_per_token_logps to None.
        generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
        if self.args.gradient_accumulation_steps % generate_every != 0:
            old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                self.model,
                prompt_completion_ids,
                attention_mask,
                logits_to_keep,
                batch_size,
                pixel_values=prompt_inputs.get("pixel_values"),
                image_grid_thw=prompt_inputs.get("image_grid_thw"),
                pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                image_sizes=prompt_inputs.get("image_sizes"),
            )
        else:
            old_per_token_logps = None

        # Compute the per-token log probabilities for the reference model
        if self.beta != 0.0:
            if self.ref_model is not None:
                ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                    self.ref_model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size=batch_size,
                    pixel_values=prompt_inputs.get("pixel_values"),
                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                    image_sizes=prompt_inputs.get("image_sizes"),
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        batch_size=batch_size,
                        pixel_values=prompt_inputs.get("pixel_values"),
                        image_grid_thw=prompt_inputs.get("image_grid_thw"),
                        pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                        image_sizes=prompt_inputs.get("image_sizes"),
                    )
        else:
            ref_per_token_logps = None

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
        completions = []
        for prompt, completion in zip(prompts, completions_text):
            bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
            completions.append([{"role": "assistant", "content": bootstrap + completion}])
    else:
        completions = completions_text

    # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
    # important because rewards will be normalized per group, and completions are distributed. We will later slice
    # rewards_per_func to extract each process's subset.
    rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)

    # Apply weights to each reward function's output and sum
    rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    advantages = rewards - mean_grouped_rewards
    if self.scale_rewards:
        advantages = advantages / (std_grouped_rewards + 1e-4)

    # Slice to keep only the local part of the data
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
    advantages = advantages[process_slice]

    # Log the metrics
    if mode == "train":
        self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
    self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

    # Log completion lengths, mean, min, max
    agg_completion_lengths = self.accelerator.gather(completion_lengths)
    self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

    # Identify sequences that terminated with EOS and log their lengths
    agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
    term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
    clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
    self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
    if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
        term_completion_lengths = torch.zeros(1, device=device)
    self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

    # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
    for i, reward_func_name in enumerate(self.reward_func_names):
        mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
        std_rewards = nanstd(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
    self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
    self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
    self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

    # Log prompt and completion texts
    self._logs["prompt"].extend(gather_object(prompts_text))
    self._logs["completion"].extend(gather_object(completions_text))
    for i, name in enumerate(self.reward_func_names):
        self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
    self._logs["advantages"].extend(all_process_advantages.tolist())

    if has_images:
        self._logs["image"].extend(gather_object(images))

    output = {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "advantages": advantages,
    }
    if old_per_token_logps is not None:
        output["old_per_token_logps"] = old_per_token_logps
    if ref_per_token_logps is not None:
        output["ref_per_token_logps"] = ref_per_token_logps
    if "pixel_values" in prompt_inputs:
        output["pixel_values"] = prompt_inputs["pixel_values"]
    if "image_grid_thw" in prompt_inputs:
        output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
    if "pixel_attention_mask" in prompt_inputs:
        output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
    if "image_sizes" in prompt_inputs:
        output["image_sizes"] = prompt_inputs["image_sizes"]
    return output

_generate_and_score_completions is the "heart" of GRPOTrainer:
in a single pass it handles prompt processing → multi-backend generation → reward scoring → group-wise normalization → and outputs every tensor the training step needs.


✅ One-sentence summary

"Turn a batch of prompts into B×G completions, score them, compute group-relative advantages, and pack everything into a training dict that can be fed straight to the loss."


✅ Core flow (8-step cheat sheet)

| Step | Key action | Code / variables |
| --- | --- | --- |
| 1️⃣ Input preparation | extract prompts, handle text and images | prompts, has_images |
| 2️⃣ Tokenization | left padding, truncation, protect special tokens | truncate_with_protected_tokens |
| 3️⃣ Generation | one of vLLM / transformers / paged attention | completion_ids |
| 4️⃣ Post-processing | truncate at the first EOS, build masks | completion_mask, completion_lengths |
| 5️⃣ Reward scoring | call _calculate_rewards | rewards_per_func |
| 6️⃣ Weighted sum | weight and sum multiple reward functions → one reward per completion | rewards |
| 7️⃣ Group normalization | mean-variance → advantages | advantages |
| 8️⃣ Cross-process sync | gather & slice for distributed consistency | gather, process_slice |

✅ Output dict (ready to feed the loss)

{"prompt_ids"          : Tensor,  # (B×G, P)"prompt_mask"         : Tensor,  # (B×G, P)"completion_ids"      : Tensor,  # (B×G, C)"completion_mask"     : Tensor,  # (B×G, C)"advantages"          : Tensor,  # (B×G,)  组内归一化优势"old_per_token_logps" : Tensor,  # 可选,重要性采样"ref_per_token_logps" : Tensor,  # 可选,KL 计算...  # 图像相关字段(若多模态)
}

✅ Recap

A single call to _generate_and_score_completions turns "prompts" into "training samples with advantages".
