[LLM Training] When roll calls Megatron to compute the loss, partial comes into play.
What does forward_step_func=partial(self.inner_forward_step, forward_func) actually mean?
Let's take a closer look at how callbacks work in Python, and how the callback function gets "triggered" in this specific scenario inside the roll framework.
The basics of Python callbacks
In Python, functions are first-class citizens. This means a function can:
- be assigned to a variable;
- be passed as an argument to another function;
- be returned from another function.
A "callback" relies on the second property.
Core idea: you define a function A and pass another function B to it as an argument. At some specific point during A's execution, A calls ("calls back") the function B you handed in.
A simple example:
```python
# Define a callback that specifies how to process a number
def square_callback(number):
    result = number * number
    print(f"Callback executed: The square of {number} is {result}")
    return result

# Define a main function that takes a list of data and a callback
def process_data(data_list, callback_function):
    print("Main function started, processing data...")
    results = []
    for item in data_list:
        # Here the main function invokes ("calls back") the function passed in
        processed_item = callback_function(item)
        results.append(processed_item)
    print("Main function finished.")
    return results

# --- Main program ---
my_data = [1, 2, 3, 4]
# Call the main function, passing square_callback as an argument
final_results = process_data(my_data, square_callback)
print(f"Final results: {final_results}")
```
Output:
```
Main function started, processing data...
Callback executed: The square of 1 is 1
Callback executed: The square of 2 is 4
Callback executed: The square of 3 is 9
Callback executed: The square of 4 is 16
Main function finished.
Final results: [1, 4, 9, 16]
```
In this example:
- process_data is the caller, analogous to strategy.forward_step; square_callback is the callback, analogous to forward_func_log_probs.
- process_data only handles the generic flow of iterating over the data; it has no idea what computation should be applied. square_callback defines the concrete computation (squaring).
- When process_data reaches the line processed_item = callback_function(item), the callback is triggered. At that moment, callback_function is simply a reference to square_callback.
How the callback is triggered in the roll framework
Now let's apply this principle to the roll code.
The participants:
- Caller: strategy.forward_step
- Callback: self.forward_func_log_probs (a method of ActorWorker)
- Trigger point: the return output_tensor, partial(loss_func, data) inside inner_forward_step, together with the internals of forward_backward_func.
Let's trace the call path and see how the callback is triggered.
Path 1: compute_log_probs -> forward_step
```python
# ActorWorker.py
def compute_log_probs(self, data: DataProto):
    # ...
    # Here, self.forward_func_log_probs is passed to forward_step
    # as a value (a callable object)
    results = self.strategy.forward_step(
        batch=data,
        forward_func=self.forward_func_log_probs  # <--- pass the callback
    )
    # ...
```
In the definition of forward_step, this callback is received and named forward_func.
```python
# MegatronInferStrategy.py (or a similar Strategy class)
def forward_step(self, batch: DataProto, forward_func: Callable):  # <--- receive the callback
    # ...
    # Pass forward_func further down
    losses_reduced = self.forward_backward_func(
        forward_step_func=partial(self.inner_forward_step, forward_func),  # <--- pass it on again
        # ...
    )
    # ...
```
Path 2: forward_step -> forward_backward_func -> inner_forward_step
forward_backward_func is a Megatron-LM function that encapsulates the complex logic of pipeline parallelism and micro-batching. Its core job is to call, in a loop, the forward_step_func you hand to it.
In our case, forward_step_func is partial(self.inner_forward_step, forward_func), so the inner loop of forward_backward_func effectively executes code like this:
```python
# Pseudocode for forward_backward_func
def forward_backward_func(forward_step_func, data_iterator, ...):
    all_outputs = []
    # Loop over the micro-batches
    for i in range(num_microbatches):
        # *** Trigger point 1 ***
        # Call the function we constructed with partial.
        # This executes inner_forward_step(data_iterator, model).
        output_tensor, process_fn = forward_step_func(data_iterator, model)

        # *** Trigger point 2 ***
        # process_fn is partial(loss_func, data),
        # i.e. partial(forward_func_log_probs, data).
        # Here the callback forward_func_log_probs is actually executed!
        loss, metrics = process_fn(output_tensor)
        all_outputs.append(metrics)
    return all_outputs
```
Now let's look at what inner_forward_step does:
```python
# MegatronInferStrategy.py
def inner_forward_step(self, loss_func, data_iterator, model):  # loss_func is forward_func_log_probs
    # ... (prepare the data)
    # 1. Run the model's forward pass
    output_tensor = model(input_ids=..., attention_mask=...)
    # 2. Prepare the callback.
    #    It does not call loss_func directly; instead it returns a partial object.
    #    partial(loss_func, data) means:
    #    "create a new callable that is equivalent to calling loss_func(data, ...),
    #     with the first argument data already filled in"
    return output_tensor, partial(loss_func, data)
```
To summarize the trigger sequence:
- compute_log_probs passes the method object forward_func_log_probs to forward_step as an argument.
- forward_step passes this method object further down to the underlying forward_backward_func.
- forward_backward_func, inside its micro-batch loop, calls inner_forward_step.
- inner_forward_step runs the model's forward pass and obtains output_tensor.
- inner_forward_step does not call forward_func_log_probs directly. Instead, it creates and returns a partial object that "wraps" forward_func_log_probs together with the current micro-batch data data.
- forward_backward_func receives output_tensor and this partial object (let's call it process_fn).
- The real trigger point: forward_backward_func calls process_fn(output_tensor). Since process_fn is a partial object wrapping forward_func_log_probs, this step is equivalent to executing:

```python
forward_func_log_probs(data, output_tensor)
```
At this point, the callback defined in ActorWorker has been triggered: it receives the model output output_tensor and the corresponding input data data, and performs its own specific computation (i.e., it calls op_compute_log_probs).
This slightly roundabout design decouples Megatron-LM's own complex looping logic from roll's higher-level business logic. forward_backward_func only needs to be given a function that returns a (tensor, callable) pair; it does not need to care what that callable actually is.
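To make that equivalence concrete, here is a minimal, self-contained sketch of the second trigger point. The names (forward_func_log_probs, data, output_tensor, process_fn) mirror the pseudocode above, but the bodies are dummy stand-ins, not actual roll or Megatron code:

```python
from functools import partial

def forward_func_log_probs(data, output_tensor):
    # Dummy stand-in for the ActorWorker callback: it sees both the
    # micro-batch data and the model output.
    return f"log-probs from {data!r} and {output_tensor!r}"

data = {"input_ids": [1, 2, 3]}   # stand-in for one micro-batch
output_tensor = "logits"          # stand-in for the model output

# What inner_forward_step returns as its second element:
process_fn = partial(forward_func_log_probs, data)

# What forward_backward_func later executes -- the call-time positional
# argument is appended after the pre-bound one:
assert process_fn(output_tensor) == forward_func_log_probs(data, output_tensor)
```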
To fit Megatron's forward_step_func interface, the function passed in may only take two arguments: a data iterator and a model. But inner_forward_step(self, loss_func, data_iterator, model) needs a third parameter, loss_func, so that extra parameter has to be pre-bound:

```python
forward_step_func = partial(self.inner_forward_step, forward_func)
```
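A minimal sketch of this signature adaptation (the Strategy class and the string arguments below are placeholders, not the actual roll classes):

```python
from functools import partial

class Strategy:
    def inner_forward_step(self, loss_func, data_iterator, model):
        # Three parameters besides self: loss_func plus the two that
        # Megatron's interface actually supplies.
        return f"forward pass using {loss_func.__name__}, {data_iterator}, {model}"

def forward_func(data, output_tensor):
    return "loss"

strategy = Strategy()

# Pre-bind loss_func so the result matches the expected
# forward_step_func(data_iterator, model) calling convention.
forward_step_func = partial(strategy.inner_forward_step, forward_func)

# Megatron only ever calls it with (data_iterator, model):
print(forward_step_func("my_data_iter", "my_model"))
```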
For reference, the docstring of Megatron-LM's get_forward_backward_func spells out exactly what this forward_step_func interface must look like:

```python
def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state.

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and return the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)

        forward_backward_func(forward_step_func=forward_step, ...)
```
The question here is how forward_step_func is used inside forward_backward_func, and what the returned partial(loss_func, loss_mask, model=model) is used for.
How forward_step_func is used
forward_step_func is the user-defined forward-pass function, and it is called inside forward_backward_func. [1]
Where it is called
forward_step_func is invoked inside the forward_step() function: [2]
The concrete call looks like this:

```python
output_tensor, loss_func = forward_step_func(data_iterator, model)
```
Here forward_step_func receives two arguments:
- data_iterator: the data iterator
- model: the model instance
and returns two values:
- output_tensor: the model's output (typically a loss tensor)
- loss_func: a partial function used for the subsequent loss processing
What partial(loss_func, ...) is for
The returned partial(loss_func, loss_mask, model=model) is consumed inside the forward_step_calc_loss() function. [3]
Usage flow
- Receive loss_func: [4]
- Call loss_func: inside forward_step_calc_loss(), if this is the last pipeline stage, this loss_func is invoked: [5]
- Handle the return value: loss_func is called with output_tensor and returns the processed loss plus metadata.
A concrete example
A complete usage example can be seen in pretrain_gpt.py: [6]
Here:
- forward_step_func calls the model to obtain output_tensor
- it returns partial(loss_func, loss_mask, model=model), binding loss_mask and model in advance
- later, in forward_step_calc_loss(), this partial function only needs to be called with output_tensor to complete the loss computation
Why use partial
The benefits of using partial are:
- Deferred execution: forward_step_func does not compute the final loss immediately; it returns a function instead
- Argument binding: context such as loss_mask and model is bound up front, so only output_tensor needs to be passed in later
- Flexibility: different pipeline stages can apply different loss-processing logic
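A minimal sketch of this binding pattern with a toy loss_func (not the real Megatron one), showing that call-time positional arguments are appended after the pre-bound ones:

```python
from functools import partial

def loss_func(loss_mask, output_tensor, model=None):
    # Toy stand-in: a real loss_func would reduce output_tensor under loss_mask
    # and return (loss, metrics).
    loss = sum(o * m for o, m in zip(output_tensor, loss_mask))
    return loss, {"lm loss": loss}

loss_mask = [1, 1, 0]
output_tensor = [2.0, 3.0, 4.0]

# What forward_step returns: loss_mask (positional) and model (keyword)
# are bound now; output_tensor is supplied later by the scheduler.
bound = partial(loss_func, loss_mask, model=None)

# Later, only the model output needs to be passed in:
assert bound(output_tensor) == loss_func(loss_mask, output_tensor, model=None)
```

An equivalent closure (lambda output: loss_func(loss_mask, output, model=model)) would achieve the same binding; partial simply makes the bound arguments explicit and inspectable via .func, .args, and .keywords.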
Notes
This design pattern is consistent across the training scripts, including pretrain_mamba.py [7] and pretrain_t5.py [8]. The unified interface is what allows Megatron-LM to support different model types and training tasks.
Citations
File: megatron/core/pipeline_parallel/schedules.py (L199-287)
```python
def forward_step_calc_loss(
    model,
    output_tensor,
    loss_func,
    config,
    vp_stage,
    collect_non_loss_data,
    num_microbatches,
    forward_data_store,
    cp_group_size=None,
    is_last_stage=None,
):
    """Calculate the loss and number of tokens for forward_step()"""
    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    model_vp_stage = getattr(model, "vp_stage", None)
    if vp_stage is not None and model_vp_stage is not None:
        assert (
            vp_stage == model_vp_stage
        ), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
    if cp_group_size is None and is_last_stage is None:
        # fallback to parallel state
        cp_group_size = parallel_state.get_context_parallel_world_size()
        is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)
    else:
        assert (
            cp_group_size is not None and is_last_stage is not None
        ), "cp_group_size and is_last_stage must be provided"

    num_tokens = torch.tensor(0, dtype=torch.int)
    if is_last_stage:
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    # Protect against division by zero when all tokens are masked
                    # in a microbatch.
                    output_tensor /= torch.clamp(num_tokens, min=1)
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor *= cp_group_size
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # Set the loss scale for Multi-Token Prediction (MTP) loss.
    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
            MTPLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    return output_tensor, num_tokens
```
File: megatron/core/pipeline_parallel/schedules.py (L290-422)
```python
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    cp_group_size,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    vp_stage=None,
    is_last_stage=True,
):
    """Forward step for passed-in model.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used.

    Args:
        forward_step_func (callable):
            The forward step function for the model that takes the
            data iterator as the first argument, and model as the second.
            This user's forward step is expected to output a tuple of two elements:

                1. The output object from the forward step. This output object needs to be a
                   tensor or some kind of collection of tensors. The only hard requirement
                   for this object is that it needs to be acceptible as input into the second
                   function.
                2. A function to reduce (optionally) the output from the forward step. This
                   could be a reduction over the loss from the model, it could be a function that
                   grabs the output from the model and reformats, it could be a function that just
                   passes through the model output. This function must have one of the following
                   patterns, and depending on the pattern different things happen internally:

                       a. A tuple of reduced loss and some other data. Note that in this case
                          the first argument is divided by the number of global microbatches,
                          assuming it is a loss, so that the loss is stable as a function of
                          the number of devices the step is split across.
                       b. A triple of reduced loss, number of tokens, and some other data. This
                          is similar to case (a), but the loss is further averaged across the
                          number of tokens in the batch. If the user is not already averaging
                          across the number of tokens, this pattern is useful to use.
                       c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
                          of tensors, etc in the case of inference). To trigger case 3 you need
                          to specify `collect_non_loss_data=True` and you may also want to
                          specify `forward_only=True` in the call to the parent forward_backward
                          function.
        data_iterator (iterator):
            The data iterator.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for the forward step.
        forward_data_store (list):
            The list to store the forward data. If you go down path 2.a or
            2.b for the return of your forward reduction function then this will store only the
            final dimension of the output, for example the metadata output by the loss function.
            If you go down the path of 2.c then this will store the entire output of the forward
            reduction function applied to the model output.
        config (object):
            The configuration object.
        collect_non_loss_data (bool, optional):
            Whether to collect non-loss data. Defaults to False.
            This is the path to use if you want to collect arbitrary output from the model forward,
            such as with inference use cases. Defaults to False.
        checkpoint_activations_microbatch (int, optional):
            The microbatch to checkpoint activations.
            Defaults to None.
        is_first_microbatch (bool, optional):
            Whether it is the first microbatch. Defaults to False.
        current_microbatch (int, optional):
            The current microbatch. Defaults to None.
        vp_stage (int, optional):
            The virtual pipeline stage. Defaults to None.
        is_last_stage (bool, optional):
            Whether it is the last stage. Defaults to True.
            Also considering virtual stages.
            In case of PP/VPP, is_last_stage/is_vp_last_stage.

    Returns:
        Tensor or list[Tensor]: The output object(s) from the forward step.
        Tensor: The number of tokens.
    """
    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    output_tensor, num_tokens = forward_step_calc_loss(
        model,
        output_tensor,
        loss_func,
        config,
        vp_stage,
        collect_non_loss_data,
        num_microbatches,
        forward_data_store,
        cp_group_size,
        is_last_stage,
    )

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
```
File: pretrain_gpt.py (L121-157)
```python
def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            if return_schedule_plan:
                assert args.overlap_moe_expert_parallel_comm, \
                    "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
                schedule_plan = model.build_schedule_plan(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )
                return schedule_plan, partial(loss_func, loss_mask, model=model)
            else:
                output_tensor = model(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)
```
File: pretrain_mamba.py (L131-154)
```python
def forward_step(data_iterator, model: MambaModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (MambaModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)
```
File: pretrain_t5.py (L174-198)
```python
def forward_step(data_iterator, model: T5Model):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (T5Model): The T5 Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator', log_level=2).start()
    use_local = args.transformer_impl == "local"
    tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = get_batch(
        data_iterator, use_local
    )
    timers('batch generator').stop()

    # Forward model lm_labels
    output_tensor = model(
        tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, lm_labels=lm_labels
    )

    return output_tensor, partial(loss_func, loss_mask)
```
