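[LLM Training] The call flow of Megatron distributed parallel training, and its key function forward_backward_func
How does the ROLL reinforcement learning framework call into Megatron to run the model?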

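Megatron's forward_backward_func is the interface it exposes to callers. It runs the forward and backward passes over all microbatches and finalizes the gradients. It does not call optimizer.step(): the caller does that afterwards, as Megatron's own train_step() shows.
ROLL obtains this forward_backward_func interface (get_forward_backward_func() returns it) and drives Megatron through it.
In summary, forward_step_func can be defined by the caller. Calling it eventually reaches the Megatron model's forward function, which runs the forward computation.
For example, ROLL defines its forward_step_func as inner_forward_step, which eventually calls model.forward. It uses functools.partial to pre-bind some of the function's arguments. This is partial function application (closer to std::bind in C++), not a C++ template class.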
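Here is what partial contributes. The schedule always calls forward_step_func(data_iterator, model), so any extra argument has to be bound in advance. This sketch assumes a hypothetical inner_forward_step(loss_func, data_iterator, model) signature, which may differ from ROLL's actual code. It also uses a plain Python callable as a stand-in for the model.

```python
from functools import partial


def toy_loss_func(output):
    # Hypothetical loss: the sum of the model outputs.
    return sum(output)


def inner_forward_step(loss_func, data_iterator, model):
    # Hypothetical signature: a leading loss_func argument that the schedule
    # knows nothing about; the schedule only supplies (data_iterator, model).
    batch = next(data_iterator)
    output = model(batch)  # this is where model.forward would run
    return output, loss_func


# partial binds loss_func now and leaves a callable with the
# (data_iterator, model) signature that forward_backward_func expects.
forward_step_func = partial(inner_forward_step, toy_loss_func)

model = lambda batch: [2 * x for x in batch]  # stand-in for model.forward
output, loss_fn = forward_step_func(iter([[1, 2, 3]]), model)
print(output, loss_fn(output))  # [2, 4, 6] 12
```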


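Megatron's own training loop uses the same forward_backward_func interface.
Question: what is the complete Megatron-LM flow, from the model forward pass through loss computation to the gradient update?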
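Overall flow
Megatron-LM training starts at train_step(). It runs the forward pass, the loss computation and the backward pass, and then updates the parameters [1].
Detailed steps
1. Training step entry point
train_step() is the entry point for a single training iteration [2].
It first zeroes the gradient buffers and the optimizer gradients [3].
2. Forward and backward passes
train_step() calls forward_backward_func to run the forward and backward passes [4].
forward_backward_func is the schedule function returned by get_forward_backward_func(), which picks an implementation based on the pipeline configuration [5].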
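The selection logic in get_forward_backward_func() can be restated as a small pure function. In this sketch, select_schedule is a hypothetical name. It takes the pipeline and virtual-pipeline world sizes as plain arguments instead of reading them from parallel_state, and it returns schedule names instead of the schedule functions themselves.

```python
from typing import Optional


def select_schedule(pp_size: int, vpp_size: Optional[int]) -> str:
    """Mirror the branch in get_forward_backward_func() (illustration only)."""
    if pp_size > 1:
        if vpp_size is not None:
            # Virtual pipeline stages configured: interleaved 1F1B.
            return "forward_backward_pipelining_with_interleaving"
        # Plain pipeline parallelism: non-interleaved 1F1B.
        return "forward_backward_pipelining_without_interleaving"
    # No pipeline parallelism: microbatches run one after another.
    return "forward_backward_no_pipelining"


for pp, vpp in [(1, None), (4, None), (4, 2)]:
    print(pp, vpp, select_schedule(pp, vpp))
```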
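3. Forward pass in detail
3.1 Calling forward_step()
Inside the schedule function, forward_step() is called for each microbatch [6].
3.2 Running the user-defined forward_step_func
forward_step() calls the user-supplied forward_step_func [7].
Taking pretrain_gpt.py as an example, forward_step_func [8]:
- Fetches the batch (tokens, labels, loss_mask, etc.)
- Calls model() for the forward pass, passing the labels argument
- Returns output_tensor (the per-token loss tensor) and partial(loss_func, loss_mask, model=model)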
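The contract between the user function and the schedule can be sketched with lists in place of tensors. user_forward_step follows the pattern of pretrain_gpt.forward_step. scheduler_forward_step is a hypothetical, heavily simplified stand-in for schedules.forward_step(): it calls the user function and then, on the last stage, feeds the output into the returned loss function.

```python
from functools import partial


def loss_func(loss_mask, output_tensor):
    # Toy masked loss sum over per-token losses (plain lists, not tensors).
    loss = sum(l * m for l, m in zip(output_tensor, loss_mask))
    num_tokens = sum(loss_mask)
    return loss, num_tokens, {"lm loss": (loss, num_tokens)}


def user_forward_step(data_iterator, model):
    # Return the model output plus a loss function that already has
    # everything bound except the output.
    tokens, loss_mask = next(data_iterator)
    output_tensor = model(tokens)
    return output_tensor, partial(loss_func, loss_mask)


def scheduler_forward_step(forward_step_func, data_iterator, model, is_last_stage=True):
    output_tensor, loss_fn = forward_step_func(data_iterator, model)
    if is_last_stage:
        return loss_fn(output_tensor)
    return output_tensor  # intermediate stages just pass activations on


per_token_loss_model = lambda tokens: [float(t) for t in tokens]  # toy "model"
result = scheduler_forward_step(user_forward_step, iter([([3, 1, 2], [1, 1, 0])]),
                                per_token_loss_model)
print(result)  # (4.0, 2, {'lm loss': (4.0, 2)})
```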
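3.3 Model forward pass
Calling model(tokens, position_ids, attention_mask, labels=labels) executes GPTModel.forward() [9].
The model forward pass consists of:
- The embedding layer processes the input
- The Transformer decoder processes the hidden states
- _postprocess() produces the logits and computes the loss
If MTP (Multi-Token Prediction) is enabled, _postprocess() also runs the MTP layers and attaches their auxiliary loss through MTPLossAutoScaler [10].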
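In the MTP loop of _postprocess(), each MTP depth shifts the labels and the loss mask one step further to the left with roll_tensor(..., shifts=-1, dims=-1). The sketch below is a 1-D, single-rank stand-in (roll_left is a hypothetical name). It assumes that roll_tensor zeroes the position that wraps around and returns the sum of the rolled tensor, which for the loss mask is the number of valid tokens.

```python
from typing import List, Tuple


def roll_left(values: List[int]) -> Tuple[List[int], int]:
    """Assumed behaviour of roll_tensor(x, shifts=-1, dims=-1) without context
    parallelism: shift left by one, zero the vacated last slot, return the sum."""
    rolled = values[1:] + [0]
    return rolled, sum(rolled)


labels = [11, 12, 13, 14, 15]  # labels used by the main LM head
loss_mask = [1, 1, 1, 1, 1]
mtp_num_layers = 2
for depth in range(mtp_num_layers):
    labels, _ = roll_left(labels)
    loss_mask, num_tokens = roll_left(loss_mask)
    print(depth, labels, loss_mask, num_tokens)
# 0 [12, 13, 14, 15, 0] [1, 1, 1, 1, 0] 4
# 1 [13, 14, 15, 0, 0] [1, 1, 1, 0, 0] 3
```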
4. Loss computation
4.1 Calling forward_step_calc_loss()
Once forward_step() has output_tensor, it calls forward_step_calc_loss() to process the loss [11].
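On the last stage, the loss is normalized according to what loss_func returned (see the schedules.py L232-248 excerpt in the citations). The sketch below restates that logic without tensors. normalize_loss is a hypothetical name, plain numbers replace tensors, and max(num_tokens, 1) stands in for torch.clamp(num_tokens, min=1).

```python
def normalize_loss(outputs, num_microbatches, calculate_per_token_loss=False, cp_group_size=1):
    """Plain-number restatement of the last-stage branch of forward_step_calc_loss()."""
    if len(outputs) == 3:
        # (loss, num_tokens, reporting) pattern.
        loss, num_tokens, _reported = outputs
        if not calculate_per_token_loss:
            # Average over tokens (guarding against a fully masked microbatch),
            # then over microbatches.
            loss = loss / max(num_tokens, 1)
            loss = loss / num_microbatches
        # With calculate_per_token_loss the loss stays a sum here; the token count
        # is handed to finalize_model_grads_func later instead.
    else:
        # Legacy (loss, reporting) pattern: average over microbatches only.
        assert len(outputs) == 2
        loss, _reported = outputs
        num_tokens = 0
        loss = loss * cp_group_size
        loss = loss / num_microbatches
    return loss, num_tokens


print(normalize_loss((8.0, 4, {}), num_microbatches=2))  # (1.0, 4)
```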
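4.2 Running loss_func
In forward_step_calc_loss(), on the last pipeline stage, the loss_func that forward_step_func returned earlier is called [12].
loss_func (from pretrain_gpt.py) does the following [13]:
- Applies the loss mask
- Computes the masked sum of the per-token losses
- Packs the detached loss and the token count into reporting_loss, which is used for logging across the data-parallel ranks
- Returns (loss, num_tokens, {'lm loss': reporting_loss})
4.3 Setting the auxiliary loss scale
For MoE and MTP, forward_step_calc_loss() sets the scale factor for the auxiliary losses [14].
This makes the auxiliary-loss gradients use the same scaling as the main loss: the grad_scale_func loss scaler (if any), divided by num_microbatches unless calculate_per_token_loss is set.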
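The scale computed for the MoE and MTP auxiliary losses reduces to a few lines once tensors are replaced by floats. In this sketch, aux_loss_scale is a hypothetical name. The lambda used in the usage example stands in for a mixed-precision loss scaler with an arbitrary loss scale of 1024.

```python
from typing import Callable, Optional


def aux_loss_scale(num_microbatches: int,
                   calculate_per_token_loss: bool,
                   grad_scale_func: Optional[Callable[[float], float]] = None) -> float:
    """Scalar restatement of the MoE/MTP branch in forward_step_calc_loss()."""
    # grad_scale_func(1) if a loss scaler is configured, otherwise 1.
    loss_scale = grad_scale_func(1.0) if grad_scale_func is not None else 1.0
    if calculate_per_token_loss:
        return loss_scale
    # Same 1/num_microbatches factor that the main loss receives.
    return loss_scale / num_microbatches


print(aux_loss_scale(4, False))                        # 0.25
print(aux_loss_scale(4, False, lambda x: x * 1024.0))  # 256.0
print(aux_loss_scale(4, True, lambda x: x * 1024.0))   # 1024.0
```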
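5. Backward pass
5.1 Calling backward_step()
Without pipeline parallelism, the backward pass for each microbatch runs right after its forward pass [15].
backward_step() performs a standard PyTorch backward pass [16].
Key steps:
- Retain the gradients of the input tensors
- On the last stage, apply loss scaling through grad_scale_func
- Call torch.autograd.backward() (or custom_backward when deallocate_pipeline_outputs is set) to compute the gradients
- Collect the input tensors' gradients so they can be passed back along the pipeline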
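The per-microbatch ordering in forward_backward_no_pipelining (see the schedules.py L593-634 excerpt in the citations) can be traced with a toy event log. In this sketch all names are hypothetical, and the no_sync context is modeled as a flag. The first num_microbatches - 1 backward passes only accumulate gradients. The last microbatch runs outside the context, so its backward is the one that would trigger gradient synchronization. forward_only skips every backward pass, as in evaluate().

```python
from contextlib import contextmanager
from typing import List


def run_no_pipelining(num_microbatches: int, forward_only: bool = False) -> List[str]:
    events: List[str] = []
    state = {"grad_sync": True}

    @contextmanager
    def no_sync():
        # Stand-in for the DDP no_sync context: suppress gradient reduction.
        state["grad_sync"] = False
        try:
            yield
        finally:
            state["grad_sync"] = True

    def step(i: int) -> None:
        events.append(f"F{i}")
        if not forward_only:
            # Tag each backward with whether gradient sync would be triggered.
            events.append(f"B{i}" + ("+sync" if state["grad_sync"] else ""))

    with no_sync():
        for i in range(num_microbatches - 1):
            step(i)
    step(num_microbatches - 1)  # last microbatch: gradients get synchronized
    return events


print(run_no_pipelining(3))                     # ['F0', 'B0', 'F1', 'B1', 'F2', 'B2+sync']
print(run_no_pipelining(3, forward_only=True))  # ['F0', 'F1', 'F2']
```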
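5.2 Backward pass for the MTP auxiliary loss
The MTP auxiliary loss triggers its own backward pass automatically through MTPLossAutoScaler [17].
forward() saves the MTP loss and returns the hidden states unchanged. backward() passes the incoming gradient through unchanged and returns a gradient of ones, scaled by main_loss_backward_scale, for the MTP loss. This way the MTP loss gradients accumulate into the main model's gradient flow automatically.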
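The same trick can be written as math: it amounts to adding a scaled sum of the MTP loss elements to the objective. The notation is introduced here for illustration. h is the hidden states passed through MTPLossAutoScaler. ℓ is the MTP loss tensor given to apply(). s is main_loss_backward_scale, presumably the value set through set_loss_scale() in forward_step_calc_loss(). α = mtp_loss_scaling_factor, K = mtp_num_layers, m = loss_mask, and N = num_tokens, all as they appear in _postprocess().

```latex
\begin{aligned}
\text{forward:}\quad & y = \operatorname{MTPLossAutoScaler}(h, \ell) = h,\\
\text{backward:}\quad & \frac{\partial \mathcal{L}}{\partial h} = \frac{\partial \mathcal{L}}{\partial y},
\qquad \frac{\partial \mathcal{L}}{\partial \ell_i} = s,\\
\text{hence:}\quad & \nabla_\theta \mathcal{L}_{\mathrm{eff}}
  = \nabla_\theta \mathcal{L}_{\mathrm{main}} + s \sum_i \nabla_\theta \ell_i,
\qquad
\ell_i = \frac{\alpha}{K}\, m_i\, \ell^{\mathrm{mtp}}_i \cdot
\begin{cases}
  1/N & \text{(default)}\\
  1   & \text{(calculate\_per\_token\_loss)}
\end{cases}
\end{aligned}
```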
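6. Gradient synchronization and parameter update
6.1 Gradient synchronization
After the backward pass, gradients are synchronized when data parallelism is enabled. With the interleaved pipeline schedule, the remaining gradient reductions are launched in the cooldown phase [18].
finalize_model_grads_func is then called to finish processing the gradients [19]. This covers the full gradient all-reduce or reduce-scatter for data parallelism, the layernorm all-reduce for sequence parallelism, and the embedding all-reduce for pipeline parallelism.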
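The data-parallel part of this step can be pictured without any framework. The sketch below simulates an all-reduce over per-rank gradient lists. It assumes two normalization modes that mirror the calculate_per_token_loss switch. By default, each rank's loss was already normalized, so the summed gradients are averaged over ranks. With per-token loss, the summed gradients are divided by the global token count. This is a simplification, not Megatron's finalize_model_grads implementation, which works on bucketed gradient buffers and real collectives. simulated_dp_grad_reduce is a hypothetical name.

```python
from typing import List, Optional


def simulated_dp_grad_reduce(per_rank_grads: List[List[float]],
                             total_num_tokens: Optional[int] = None) -> List[float]:
    """Toy all-reduce across data-parallel ranks (lists stand in for grad buffers)."""
    dp_size = len(per_rank_grads)
    # Element-wise sum across ranks: what an all-reduce(SUM) produces.
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    if total_num_tokens is None:
        # Default mode (assumed): average over data-parallel ranks.
        return [g / dp_size for g in summed]
    # Per-token mode (assumed): divide by the global number of tokens.
    return [g / total_num_tokens for g in summed]


print(simulated_dp_grad_reduce([[1.0, 2.0], [3.0, 4.0]]))                      # [2.0, 3.0]
print(simulated_dp_grad_reduce([[1.0, 2.0], [3.0, 4.0]], total_num_tokens=4))  # [1.0, 1.5]
```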
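6.2 Optimizer update
Back in train_step(), optimizer.step() is called to update the parameters [20].
The optimizer:
- Applies gradient clipping
- Performs the parameter update (e.g. Adam)
- Returns update_successful, grad_norm and num_zeros_in_grad
Finally, the update status is reconciled across the model-parallel group: update_successful is combined with a logical AND, and grad_norm with a max reduction [21].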
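The cross-rank bookkeeping after optimizer.step() can be mimicked with per-rank Python values. This sketch follows the semantics described in the train_step() comments: one failing rank fails the update, and ranks without trainable parameters report None. The list-based helpers are hypothetical stand-ins for logical_and_across_model_parallel_group and reduce_max_stat_across_model_parallel_group, and they skip the actual collective communication.

```python
from typing import List, Optional


def logical_and_across_ranks(flags: List[bool]) -> bool:
    # One failing rank (e.g. with frozen sub-models) fails the whole update.
    return all(flags)


def reduce_max_across_ranks(stats: List[Optional[float]]) -> Optional[float]:
    # Ranks without trainable parameters report None; ignore them.
    present = [s for s in stats if s is not None]
    return max(present) if present else None


print(logical_and_across_ranks([True, True, False]))  # False
print(reduce_max_across_ranks([None, 1.5, 2.0]))      # 2.0
```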
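Complete flow chart
train_step() → zero_grad_buffer() and optimizer.zero_grad() → forward_backward_func() → [per microbatch: forward_step() → forward_step_func() → model() → forward_step_calc_loss() → loss_func() → backward_step()] → gradient sync and finalize_model_grads_func() → optimizer.step() → reduction of update_successful and grad_norm across the model-parallel group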
Notes
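This flow supports combinations of several parallelism strategies (TP, PP, DP, CP, EP). With pipeline parallelism, a 1F1B (one-forward-one-backward) schedule is used [22]. It bounds the number of in-flight microbatches, and therefore the activation memory, on each stage. The interleaved variant additionally shrinks the pipeline bubble. Auxiliary losses such as MTP are folded into the main training flow through custom autograd functions, so multiple loss terms do not have to be managed explicitly.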
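The op order on one stage of the non-interleaved 1F1B schedule can be generated in a few lines. This sketch uses the usual warmup count, min(pp_size - pp_rank - 1, num_microbatches), and ignores communication entirely. one_f_one_b is a hypothetical name, not the forward_backward_pipelining_without_interleaving implementation.

```python
from typing import List


def one_f_one_b(pp_size: int, pp_rank: int, num_microbatches: int) -> List[str]:
    """Op order on one stage of a non-interleaved 1F1B schedule (no communication)."""
    # Earlier stages run more warmup forwards so that later stages have work.
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup
    ops: List[str] = []
    fwd = bwd = 0
    for _ in range(num_warmup):  # warmup: forwards only
        ops.append(f"F{fwd}")
        fwd += 1
    for _ in range(num_remaining):  # steady state: one forward, one backward
        ops.append(f"F{fwd}")
        fwd += 1
        ops.append(f"B{bwd}")
        bwd += 1
    for _ in range(num_warmup):  # cooldown: drain the outstanding backwards
        ops.append(f"B{bwd}")
        bwd += 1
    return ops


for rank in range(4):
    print(rank, one_f_one_b(4, rank, 6))
```

At most num_warmup + 1 microbatches have live activations on a stage at any time. This is the memory saving over running every forward before any backward.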
Citations
File: megatron/training/training.py (L1228-1297)
```python
def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        # Set grad to zero.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        if has_nvidia_modelopt:
            # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
            adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(
                model, args.seq_length, args.micro_batch_size, args.decoder_seq_length
            )
        else:
            adjust_tensor_shapes_fn = None

        # For the mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap,
        # we need to call the _copy_main_params_to_param_buffer() after the grad buffer
        # is zeroed by zero_grad_buffer() because param and grad buffer are shared.
        if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
            for optim_instance in optimizer.chained_optimizers:
                if isinstance(optim_instance, DistributedOptimizer):
                    optim_instance._copy_main_params_to_param_buffer()

        # Forward pass.
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            forward_only=False,
            adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
        )
    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Vision gradients.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # when freezing sub-models we may have a mixture of successful and unsucessful ranks,
    # so we must gather across mp ranks
    update_successful = logical_and_across_model_parallel_group(update_successful)
    # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
    # so we must gather across mp ranks
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    # Vision momentum.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
```
File: megatron/core/pipeline_parallel/schedules.py (L40-132)
```python
def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state.

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and return the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)

        forward_backward_func(forward_step_func=forward_step, ...)

    data_iterator (required): an iterator over the data, will be
        passed as is to forward_step_func. Expected to be a list of
        iterators in the case of interleaved pipeline parallelism.

    model (required): the actual model. Expected to be a list of modules in the case of interleaved
        pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

    num_microbatches (int, required):
        The number of microbatches to go through

    seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
        transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
        in the config is True. Otherwise, each microbatch in the current global batch size must use
        this sequence length.

    micro_batch_size (int, required): The number of sequences in a microbatch.

    decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
        transformer. This is ignored for a single-stack transformer.

    forward_only (optional, default = False): Perform only the forward step

    collect_non_loss_data (optional, bool, default=False): TODO

    first_val_step (bool, optional): Is the first step of the validation phase. Used by
        Transformer Engine modules to only update their fp8 weights only on the first validation
        step.

    adjust_tensor_shapes_fn (Callable, optional): A function that adjusts the receive and send
        tensor shapes. Only applicable in forward_backward_pipelining_without_interleaving for now.
        Takes in a list of receive shapes and a list of send shapes and returns the adjusted
        respective list of shapes. Thus it is not used in the other forward-backward functions
        which have different shape handling.
    """
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    if pipeline_model_parallel_size > 1:
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
```
File: megatron/core/pipeline_parallel/schedules.py (L232-248)
```python
num_tokens = torch.tensor(0, dtype=torch.int)
if is_last_stage:
    if not collect_non_loss_data:
        outputs = loss_func(output_tensor)
        if len(outputs) == 3:
            output_tensor, num_tokens, loss_reduced = outputs
            if not config.calculate_per_token_loss:
                # Protect against division by zero when all tokens are masked
                # in a microbatch.
                output_tensor /= torch.clamp(num_tokens, min=1)
                output_tensor /= num_microbatches
        else:
            # preserve legacy loss averaging behavior (ie, over the number of microbatches)
            assert len(outputs) == 2
            output_tensor, loss_reduced = outputs
            output_tensor *= cp_group_size
            output_tensor /= num_microbatches
```
File: megatron/core/pipeline_parallel/schedules.py (L257-285)
```python
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale
# explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
    # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
    loss_scale = (
        config.grad_scale_func(torch.ones(1, device=output_tensor.device))
        if config.grad_scale_func is not None
        else torch.ones(1, device=output_tensor.device)
    )
    # Set the loss scale
    if config.calculate_per_token_loss:
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
    else:
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
    # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
    loss_scale = (
        config.grad_scale_func(torch.ones(1, device=output_tensor.device))
        if config.grad_scale_func is not None
        else torch.ones(1, device=output_tensor.device)
    )
    # Set the loss scale
    if config.calculate_per_token_loss:
        MTPLossAutoScaler.set_loss_scale(loss_scale)
    else:
        MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
```
File: megatron/core/pipeline_parallel/schedules.py (L290-422)
```python
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    cp_group_size,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    vp_stage=None,
    is_last_stage=True,
):
    """Forward step for passed-in model.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used.

    Args:
        forward_step_func (callable):
            The forward step function for the model that takes the
            data iterator as the first argument, and model as the second.
            This user's forward step is expected to output a tuple of two elements:

                1. The output object from the forward step. This output object needs to be a
                   tensor or some kind of collection of tensors. The only hard requirement
                   for this object is that it needs to be acceptible as input into the second
                   function.
                2. A function to reduce (optionally) the output from the forward step. This
                   could be a reduction over the loss from the model, it could be a function that
                   grabs the output from the model and reformats, it could be a function that just
                   passes through the model output. This function must have one of the following
                   patterns, and depending on the pattern different things happen internally:

                        a. A tuple of reduced loss and some other data. Note that in this case
                           the first argument is divided by the number of global microbatches,
                           assuming it is a loss, so that the loss is stable as a function of
                           the number of devices the step is split across.
                        b. A triple of reduced loss, number of tokens, and some other data. This
                           is similar to case (a), but the loss is further averaged across the
                           number of tokens in the batch. If the user is not already averaging
                           across the number of tokens, this pattern is useful to use.
                        c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
                           of tensors, etc in the case of inference). To trigger case 3 you need
                           to specify `collect_non_loss_data=True` and you may also want to
                           specify `forward_only=True` in the call to the parent forward_backward
                           function.
        data_iterator (iterator):
            The data iterator.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for the forward step.
        forward_data_store (list):
            The list to store the forward data. If you go down path 2.a or
            2.b for the return of your forward reduction function then this will store only the
            final dimension of the output, for example the metadata output by the loss function.
            If you go down the path of 2.c then this will store the entire output of the forward
            reduction function applied to the model output.
        config (object):
            The configuration object.
        collect_non_loss_data (bool, optional):
            Whether to collect non-loss data. Defaults to False.
            This is the path to use if you want to collect arbitrary output from the model forward,
            such as with inference use cases. Defaults to False.
        checkpoint_activations_microbatch (int, optional):
            The microbatch to checkpoint activations.
            Defaults to None.
        is_first_microbatch (bool, optional):
            Whether it is the first microbatch. Defaults to False.
        current_microbatch (int, optional):
            The current microbatch. Defaults to None.
        vp_stage (int, optional):
            The virtual pipeline stage. Defaults to None.
        is_last_stage (bool, optional):
            Whether it is the last stage. Defaults to True.
            Also considering virtual stages.
            In case of PP/VPP, is_last_stage/is_vp_last_stage.

    Returns:
        Tensor or list[Tensor]: The output object(s) from the forward step.
        Tensor: The number of tokens.
    """
    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )
    output_tensor, num_tokens = forward_step_calc_loss(
        model,
        output_tensor,
        loss_func,
        config,
        vp_stage,
        collect_non_loss_data,
        num_microbatches,
        forward_data_store,
        cp_group_size,
        is_last_stage,
    )

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
```
File: megatron/core/pipeline_parallel/schedules.py (L425-486)
```python
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    # In multi-modal models like VLM, some batches may not have images.
    # When no image is present, the vision encoder (as a separate pipeline stage)
    # will not participate in the computation.
    # This results in a tensor that does not require gradients.
    # In such cases, we intentionally skip the backward pass while preserving zero gradients.
    if output_tensor[0].requires_grad:
        if config.deallocate_pipeline_outputs:
            custom_backward(output_tensor[0], output_tensor_grad[0])
        else:
            torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            if x is None:
                input_tensor_grad.append(None)
            else:
                input_tensor_grad.append(x.grad)

    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad
```
File: megatron/core/pipeline_parallel/schedules.py (L593-634)
```python
with no_sync_func():
    for i in range(num_microbatches - 1):
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            pg_collection.cp.size(),
            collect_non_loss_data,
            is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
        )
        total_num_tokens += num_tokens
        if not forward_only:
            backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor, num_tokens = forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    pg_collection.cp.size(),
    collect_non_loss_data,
    is_first_microbatch=check_first_val_step(
        first_val_step, forward_only, num_microbatches == 1
    ),
    current_microbatch=num_microbatches - 1,
)
total_num_tokens += num_tokens

if not forward_only:
    backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
```
File: megatron/core/pipeline_parallel/schedules.py (L1870-1877)
```python
# Launch any remaining grad reductions.
enable_grad_sync()
if config.grad_sync_func is not None:
    for model_chunk_id in range(num_model_chunks):
        if model_chunk_id not in synchronized_model_chunks:
            config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
            synchronized_model_chunks.add(model_chunk_id)
nvtx_range_pop(suffix="cooldown")
```
File: megatron/core/pipeline_parallel/schedules.py (L1887-1903)
```python
if config.finalize_model_grads_func is not None and not forward_only:

    # If defer_embedding_wgrad_compute is enabled we need to do the
    # weight gradient GEMM's here.
    finish_embedding_wgrad_compute(
        config, embedding_module, is_pp_last_stage(p2p_communicator.pp_group), tp_group
    )

    # Finalize model grads (perform full grad all-reduce / reduce-scatter for
    # data parallelism, layernorm all-reduce for sequence parallelism, and
    # embedding all-reduce for pipeline parallelism).
    config.finalize_model_grads_func(
        model,
        total_num_tokens if config.calculate_per_token_loss else None,
        pg_collection=pg_collection,
    )
```
File: megatron/core/pipeline_parallel/schedules.py (L1949-1965)
```python
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: Optional[int] = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: Optional[bool] = None,
    adjust_tensor_shapes_fn: Optional[Callable] = None,
    p2p_communicator: Optional[P2PCommunicator] = None,
    pg_collection: Optional[ProcessGroupCollection] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
```
File: pretrain_gpt.py (L59-118)
```python
def loss_func(
    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        output_tensor (torch.Tensor): The tensor with the losses
        model (GPTModel, optional): The model (can be wrapped)

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this microbatch
        a dict containing reporting metrics on the loss and number of tokens across
            the data parallel ranks
    """
    args = get_args()

    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
        return loss_func_modelopt(loss_mask, output_tensor, model=model)

    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask)

    # Check individual rank losses are not NaN prior to DP all-reduce.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=True,
        )
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isinf,
            message="found Inf in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=True,
        )
    # Check for spiky loss
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=partial(
                rerun_state_machine.is_unexpectedly_large,
                threshold=SPIKY_LOSS_FACTOR,
                context="loss",
            ),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=False,
        )

    num_tokens = loss_mask.sum().clone().detach().to(torch.int)
    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])

    return (loss, num_tokens, {'lm loss': reporting_loss})
```
File: pretrain_gpt.py (L121-157)
```python
def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            if return_schedule_plan:
                assert args.overlap_moe_expert_parallel_comm, \
                    "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
                schedule_plan = model.build_schedule_plan(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )
                return schedule_plan, partial(loss_func, loss_mask, model=model)
            else:
                output_tensor = model(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)
```
File: megatron/core/models/gpt/gpt_model.py (L486-580)
```python
def _postprocess(
    self,
    hidden_states,
    input_ids,
    position_ids,
    labels,
    rotary_pos_emb,
    rotary_pos_cos,
    rotary_pos_sin,
    mtp_in_postprocess=None,
    loss_mask=None,
    decoder_input=None,
    attention_mask=None,
    inference_params=None,
    packed_seq_params=None,
    sequence_len_offset=None,
    runtime_gather_output=None,
    extra_block_kwargs=None,
    inference_context=None,
):
    """Postprocesses decoder hidden states to generate logits or compute loss.

    Applies Multi-Token Prediction if enabled, generates output logits through
    the output layer, and computes language model loss when labels are provided.
    """
    in_inference_mode = inference_context is not None and not self.training
    if in_inference_mode:
        assert runtime_gather_output, "Inference must always gather TP logits"

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()

    if mtp_in_postprocess:
        hidden_states = self.mtp(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=self.embedding,
            **(extra_block_kwargs or {}),
        )

    if not self.post_process:
        return hidden_states

    if self.mtp_process:
        mtp_labels = labels.clone()
        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
        hidden_states = hidden_states_list[0]
        if loss_mask is None:
            # if loss_mask is not provided, use all ones as loss_mask
            loss_mask = torch.ones_like(mtp_labels)
        for mtp_layer_number in range(self.config.mtp_num_layers):
            # output
            mtp_logits, _ = self.output_layer(
                hidden_states_list[mtp_layer_number + 1],
                weight=output_weight,
                runtime_gather_output=runtime_gather_output,
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
            loss_mask, num_tokens = roll_tensor(
                loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
            )
            mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                # TODO(shifangx): remove the use of parallel_state here
                # after moving loss logging to loss_func in pretrain_gpt.py
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    mtp_layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_data_parallel_group(
                        with_context_parallel=True
                    ),
                )
            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
            if self.config.calculate_per_token_loss:
                hidden_states = MTPLossAutoScaler.apply(
                    hidden_states, mtp_loss_scale * mtp_loss
                )
            else:
                hidden_states = MTPLossAutoScaler.apply(
                    hidden_states, mtp_loss_scale * mtp_loss / num_tokens
                )
    sequence_parallel_override = False
```
File: megatron/core/transformer/multi_token_prediction.py (L356-389)
```python
class MTPLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
        """Preserve the mtp by storing it in the context to avoid garbage collection.

        Args:
            output (torch.Tensor): The output tensor.
            mtp_loss (torch.Tensor): The mtp loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(mtp_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for mtp loss..

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
                                               gradient.
        """
        (mtp_loss,) = ctx.saved_tensors
        mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
        scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
        return grad_output, scaled_mtp_loss_grad
```
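Question: who passes in forward_step_func, and where does a call to it end up?
Where forward_step_func comes from
forward_step_func is passed in as an argument to the pretrain() function [1].
How it is passed along
- Defined in the training script: pretrain_gpt.py defines a forward_step function [2]
- Passed to pretrain(): the main entry point of pretrain_gpt.py passes forward_step to pretrain() as the forward_step_func argument [2]
- Passed to train_step(): pretrain() hands forward_step_func on, through train(), to train_step() [3]
- Passed to forward_backward_func: train_step() passes forward_step_func as an argument when it calls forward_backward_func [4]
- Used in the schedule function: the schedule function (e.g. forward_backward_no_pipelining) receives the forward_step_func argument [5]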
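The hand-off can be mimicked with stub functions that only record the call order. The stubs reuse Megatron's function names for readability but contain none of its logic. The loss function here (the built-in sum) and the toy model are placeholders. The sketch assumes, as in Megatron's training.py, that pretrain() reaches train_step() through train().

```python
from typing import List

calls: List[str] = []


def forward_step(data_iterator, model):
    # User-defined forward_step_func (plays the role of pretrain_gpt.forward_step).
    calls.append("forward_step_func")
    batch = next(data_iterator)
    return model(batch), sum  # placeholder loss function: sum of the outputs


def schedule_forward_step(forward_step_func, data_iterator, model):
    calls.append("schedules.forward_step")
    output, loss_func = forward_step_func(data_iterator, model)
    return loss_func(output)


def forward_backward_func(*, forward_step_func, data_iterator, model, num_microbatches):
    calls.append("forward_backward_func")
    return [schedule_forward_step(forward_step_func, data_iterator, model)
            for _ in range(num_microbatches)]


def train_step(forward_step_func, data_iterator, model):
    calls.append("train_step")
    return forward_backward_func(forward_step_func=forward_step_func,
                                 data_iterator=data_iterator, model=model,
                                 num_microbatches=2)


def train(forward_step_func, data_iterator, model):
    calls.append("train")
    return train_step(forward_step_func, data_iterator, model)


def pretrain(forward_step_func, data_iterator, model):
    calls.append("pretrain")
    return train(forward_step_func, data_iterator, model)


losses = pretrain(forward_step, iter([[1, 2], [3, 4]]), lambda b: [x * 10 for x in b])
print(losses)  # [30, 70]
print(" -> ".join(calls[:6]))
```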
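Where a call to forward_step_func ends up
When forward_step_func is called, what actually runs is the function the user defined in the training script.
Call site
forward_step_func is called inside forward_step() in schedules.py [6].
What it runs (using pretrain_gpt.py as an example)
Calling forward_step_func(data_iterator, model) runs the forward_step function in pretrain_gpt.py [2].
Concrete steps:
- Fetch the batch: call get_batch() to obtain tokens, labels, loss_mask, etc. [7]
- Model forward pass: call model() with tokens, position_ids, attention_mask and labels (plus loss_mask on the non-legacy path) [8]
- Return: output_tensor and partial(loss_func, loss_mask, model=model) [9]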
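Implementations in other training scripts
Different training scripts provide their own forward_step_func:
- pretrain_mamba.py: the same pattern, adapted to the Mamba model
- train_rl.py: for reinforcement learning training; computes the GRPO loss [10]
- Test code: the unit tests also define simplified forward_step_func versions [11]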
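train_rl.py delegates the loss to calculate_grpo_loss. For readers unfamiliar with that objective, here is a generic per-token GRPO/PPO-style loss: a clipped importance-ratio surrogate plus a KL penalty toward the reference policy, averaged over unmasked tokens. The function names, default epsilons, and the k3 KL estimator are assumptions for illustration. Megatron's calculate_grpo_loss may differ; for example, it also handles an entropy term and importance-sampling truncation.

```python
import math
from typing import List


def grpo_token_loss(cur_logp: float, old_logp: float, ref_logp: float, advantage: float,
                    eps_lower: float = 0.2, eps_upper: float = 0.2,
                    kl_beta: float = 0.0) -> float:
    """Generic clipped surrogate + KL penalty for one token (not Megatron's exact code)."""
    ratio = math.exp(cur_logp - old_logp)
    clipped = min(max(ratio, 1.0 - eps_lower), 1.0 + eps_upper)
    surrogate = min(ratio * advantage, clipped * advantage)
    # k3 estimator of KL(current || reference): exp(d) - d - 1 with d = ref - cur.
    d = ref_logp - cur_logp
    kl = math.exp(d) - d - 1.0
    return -surrogate + kl_beta * kl


def masked_mean(values: List[float], mask: List[int]) -> float:
    # Average only over tokens with mask == 1 (guarding against an empty mask).
    total = sum(v * m for v, m in zip(values, mask))
    return total / max(sum(mask), 1)


# Ratio 1.5 with a positive advantage is clipped to 1 + eps_upper = 1.2.
token_losses = [grpo_token_loss(math.log(1.5), 0.0, math.log(1.5), 1.0),
                grpo_token_loss(0.0, 0.0, 0.0, 1.0)]
print(masked_mean(token_losses, [1, 1]))  # approximately -1.1
```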
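Complete call chain
pretrain_gpt.py entry point → pretrain() → train() → train_step() → forward_backward_func (e.g. forward_backward_no_pipelining) → schedules.forward_step() → forward_step_func (pretrain_gpt.forward_step) → model(...) → GPTModel.forward()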
Notes
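This design keeps the Megatron-LM training framework flexible. Users can define their own forward_step_func in a training script to fit different models and training tasks, while the core scheduling logic (pipeline parallelism, gradient accumulation, and so on) stays unchanged. During evaluation, evaluate() uses the same forward_step_func, only with forward_only=True [12].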
Citations
File: megatron/training/training.py (L530-543)
```python
def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    model_type,
    forward_step_func,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults={},
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    non_loss_data_func=None,
    store=None,
    inprocess_call_wrapper: Optional[CallWrapper] = None,
):
```
File: megatron/training/training.py (L1228-1228)
```python
def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func):
```
File: megatron/training/training.py (L1257-1267)
```python
losses_reduced = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=data_iterator,
    model=model,
    num_microbatches=get_num_microbatches(),
    seq_length=args.seq_length,
    micro_batch_size=args.micro_batch_size,
    decoder_seq_length=args.decoder_seq_length,
    forward_only=False,
    adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
)
```
File: megatron/training/training.py (L2508-2517)
```python
loss_dicts = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=data_iterator,
    model=model,
    num_microbatches=eval_num_microbatches,
    seq_length=args.seq_length,
    micro_batch_size=args.micro_batch_size,
    decoder_seq_length=args.decoder_seq_length,
    forward_only=True,
)
```
File: pretrain_gpt.py (L121-157)
(Same forward_step excerpt as quoted in the previous citation list.)
File: megatron/core/pipeline_parallel/schedules.py (L401-406)
```python
if checkpoint_activations_microbatch is None:
    output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
    output_tensor, loss_func = forward_step_func(
        data_iterator, model, checkpoint_activations_microbatch
    )
```
File: megatron/core/pipeline_parallel/schedules.py (L497-511)
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: Optional[int] = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: Optional[bool] = None,
    adjust_tensor_shapes_fn: Optional[Callable] = None,  # unused
    pg_collection: Optional[ProcessGroupCollection] = None,
):
File: train_rl.py (L156-215)
def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        (
            tokens,
            advantages,
            old_logprobs,
            loss_mask,
            position_ids,
            ref_logprobs,
            inference_logprobs,
        ) = next(data_iterator)
    timers('batch-generator').stop()

    tokens = tokens.cuda()
    position_ids = position_ids.cuda()
    old_logprobs = old_logprobs.cuda()
    ref_logprobs = ref_logprobs.cuda()
    advantages = advantages.cuda()
    inference_logprobs = (
        inference_logprobs.cuda() if args.rl_inference_logprobs_is_correction else None
    )

    with stimer:
        current_logprobs = get_logprobs(model, tokens, position_ids, None, no_grad=False)
        loss, kl_term, ratios, entropy_term, truncated_from_above, truncated_from_below = (
            calculate_grpo_loss(
                current_logprobs=current_logprobs,
                old_logprobs=old_logprobs,
                ref_logprobs=ref_logprobs,
                advantages=advantages,
                clamp_eps_lower=args.grpo_clamp_eps_lower,
                clamp_eps_upper=args.grpo_clamp_eps_upper,
                kl_beta=args.grpo_kl_beta,
                entropy_weight=args.grpo_entropy_term_weight,
                inference_logprobs=inference_logprobs,
                is_truncation_coef=args.rl_importance_sampling_truncation_coef,
            )
        )

    # loss_mask will not be applied to 0th token as we do not have a logprob for it.
    return loss, partial(
        loss_func,
        loss_mask[:, 1:].contiguous(),
        kl_term,
        ratios,
        entropy_term,
        truncated_from_above,
        truncated_from_below,
    )
File: tests/unit_tests/pipeline_parallel/test_schedules.py (L138-147)
def forward_step_func(data_iterator, model):
    import os

    rank = int(os.environ['LOCAL_RANK'])
    dummy_data = torch.ones(1, 4)

    def loss_func(output_tensor):
        return rank, {'loss_reduced': rank}

    return model(dummy_data), loss_func
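You asked what forward_backward_func does and why it is needed. 1
Core role
forward_backward_func is a scheduling function. It coordinates the order in which the forward and backward passes run over the microbatches of a training step, which matters most when pipeline parallelism is used. 2
Why it is needed
1. Support for multiple parallelism strategies
get_forward_backward_func() automatically selects the appropriate schedule from the pipeline configuration: 3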
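- No pipeline parallelism (pipeline_model_parallel_size == 1): returns forward_backward_no_pipelining, which simply runs all microbatches one after another. 4
- Pipeline parallelism without a virtual pipeline: returns forward_backward_pipelining_without_interleaving, which uses the standard 1F1B schedule. 5
- Pipeline parallelism with a virtual pipeline (a virtual pipeline size is set): returns forward_backward_pipelining_with_interleaving, which uses an interleaved 1F1B schedule to reduce the pipeline bubble. 6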
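The selection rule can be mirrored in a few lines. This sketch is only a stand-in, under two assumptions:

- The parallel sizes are passed in as arguments instead of being read from `parallel_state`.
- The three schedules are represented by their names as strings.

```python
from typing import Optional


def pick_schedule(pipeline_model_parallel_size: int,
                  virtual_pipeline_model_parallel_size: Optional[int]) -> str:
    """Mirror of the branch structure in get_forward_backward_func()."""
    if pipeline_model_parallel_size > 1:
        # The virtual pipeline size is only consulted when PP > 1.
        if virtual_pipeline_model_parallel_size is not None:
            return "forward_backward_pipelining_with_interleaving"
        return "forward_backward_pipelining_without_interleaving"
    return "forward_backward_no_pipelining"


if __name__ == "__main__":
    # Same configurations as test_get_forward_backward_func().
    for pp, vpp in [(1, None), (4, None), (4, 2), (2, 4)]:
        print(pp, vpp, pick_schedule(pp, vpp))
```

This branch structure never looks at the virtual pipeline size when the pipeline size is 1.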
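2. A unified interface
Every schedule exposes the same function interface, so training code does not need to know which parallelism strategy runs underneath. In train_step(), it simply calls forward_backward_func with the standard arguments: 7
With this design, switching the parallel configuration requires no changes to the training code itself.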
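Here is a toy version of the no-pipelining schedule, plus a caller that works with any function of the same signature. It illustrates why a single interface keeps training code unchanged. Everything below is hypothetical:

- `toy_no_pipelining` only loops over the microbatches.
- It records whether a backward pass would have run.
- It skips that step when `forward_only=True`, the flag the training loop sets to False and the evaluation loop sets to True.

```python
from typing import Callable, Dict, Iterator, List


def toy_no_pipelining(*, forward_step_func: Callable, data_iterator: Iterator,
                      model: Callable, num_microbatches: int,
                      forward_only: bool = False) -> List[Dict[str, object]]:
    """Toy single-stage schedule: run the microbatches one after another."""
    losses_reduced = []
    for _ in range(num_microbatches):
        output_tensor, loss_func = forward_step_func(data_iterator, model)
        loss, logs = loss_func(output_tensor)
        if not forward_only:
            # A real schedule would backpropagate `loss` here; this toy
            # version only records that a backward pass would have run.
            logs = {**logs, "did_backward": True}
        losses_reduced.append(logs)
    return losses_reduced


def run(forward_backward_func: Callable, forward_step_func: Callable,
        batches: list, model: Callable, forward_only: bool):
    """Caller that depends only on the shared keyword interface."""
    return forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=iter(batches),
        model=model,
        num_microbatches=len(batches),
        forward_only=forward_only,
    )


def forward_step(data_iterator, model):
    """Toy forward_step_func: the output itself is used as the loss."""
    output = model(next(data_iterator))
    return output, lambda out: (out, {"loss": out})


if __name__ == "__main__":
    double = lambda x: 2 * x
    print(run(toy_no_pipelining, forward_step, [1.0, 2.0], double, forward_only=False))
    print(run(toy_no_pipelining, forward_step, [1.0, 2.0], double, forward_only=True))
```

You could swap `toy_no_pipelining` for any other function with the same keyword signature and `run` would not change. That is the property the text relies on for `train_step()` and evaluation.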
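3. Handling the pipeline bubble
For pipeline parallelism, forward_backward_func implements the 1F1B (one-forward-one-backward) schedule, which orders the work in three phases: 8
- Warmup phase: run several forward passes to fill the pipeline
- 1F1B (steady) phase: alternate one forward and one backward pass, keeping the pipeline busy
- Cooldown phase: finish the remaining backward passes
Compare this with a GPipe-style schedule that runs every forward pass before any backward pass. Plain 1F1B has the same bubble size, but it keeps at most as many microbatches in flight as there are pipeline stages, which bounds activation memory. The bubble itself shrinks as the number of microbatches grows, and the interleaved variant reduces it further, improving GPU utilization.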
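The three phases of the 1F1B schedule become concrete when you list the operations one pipeline stage executes. This is a simplified sketch that ignores communication and timing. It uses two formulas:

- **Warmup count:** `min(p - rank - 1, m)`, following the non-interleaved 1F1B schedule.
- **Bubble estimate:** the commonly cited idealized ratio of idle time to ideal compute time, `(p - 1) / (v * m)`. Here `p` is the number of stages, `m` the number of microbatches, and `v` the number of virtual stages per device (`v = 1` without interleaving). The estimate assumes equal per-stage forward and backward cost.

```python
from typing import List


def one_f_one_b_order(p: int, rank: int, m: int) -> List[str]:
    """Operation order on one stage of a non-interleaved 1F1B pipeline.

    p: number of pipeline stages, rank: this stage (0-based),
    m: number of microbatches. "F3" means forward of microbatch 3.
    """
    num_warmup = min(p - rank - 1, m)
    num_steady = m - num_warmup
    ops: List[str] = []
    next_f, next_b = 0, 0
    # Warmup: forwards only, to fill the pipeline.
    for _ in range(num_warmup):
        ops.append(f"F{next_f}")
        next_f += 1
    # Steady state: one forward followed by one backward.
    for _ in range(num_steady):
        ops.append(f"F{next_f}")
        next_f += 1
        ops.append(f"B{next_b}")
        next_b += 1
    # Cooldown: drain the backwards left over from warmup.
    for _ in range(num_warmup):
        ops.append(f"B{next_b}")
        next_b += 1
    return ops


def bubble_ratio(p: int, m: int, v: int = 1) -> float:
    """Idealized bubble time / ideal compute time: (p - 1) / (v * m)."""
    return (p - 1) / (v * m)


if __name__ == "__main__":
    for rank in range(4):
        print(rank, " ".join(one_f_one_b_order(p=4, rank=rank, m=6)))
    print(bubble_ratio(4, 8), bubble_ratio(4, 8, v=2))
```

Stage 0 holds `p` microbatches in flight before its first backward, while the last stage alternates F and B from the start. This is why activation memory is bounded by the stage count rather than by `m`.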
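4. Support for optimizations such as CUDA Graphs
forward_backward_func can be wrapped to enable further optimizations, for example full-iteration CUDA Graphs: 9
Once wrapped in FullCudaGraphWrapper, the entire forward and backward pass can be captured as a CUDA Graph and replayed, which further improves performance. 10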
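Wrapping works because the wrapper is a drop-in callable with the same keyword interface as the function it wraps. Below is a minimal sketch of that pattern that assumes nothing about CUDA. `CountingWrapper` is hypothetical and only shows the control flow:

- It keeps separate iteration counters for training and validation, in the spirit of FullCudaGraphWrapper's `curr_iteration` attribute. Unlike that class, these counters are per instance.
- It derives the mode from `forward_only`, which is an assumption.
- It runs the wrapped function eagerly for a number of warmup steps, then marks the point where a real implementation would capture once and then replay.

```python
from typing import Any, Callable, Dict, List


class CountingWrapper:
    """Hypothetical stand-in for a full-iteration graph wrapper.

    Nothing is captured: every call still runs the wrapped function. The
    class only records whether a call falls in the eager warmup phase or
    in the phase where a real wrapper would capture/replay a graph.
    """

    def __init__(self, forward_backward_func: Callable, warmup_steps: int = 1):
        self.forward_backward_func = forward_backward_func
        self.warmup_steps = warmup_steps
        self.curr_iteration: Dict[str, int] = {"training": 0, "validation": 0}
        self.phases: List[str] = []

    def __call__(self, *args: Any, **kwargs: Any):
        # Assumption: the mode is derived from the forward_only flag.
        mode = "validation" if kwargs.get("forward_only", False) else "training"
        if self.curr_iteration[mode] < self.warmup_steps:
            self.phases.append(f"{mode}:eager")
        else:
            # A real implementation would capture once, then replay here.
            self.phases.append(f"{mode}:graph")
        self.curr_iteration[mode] += 1
        return self.forward_backward_func(*args, **kwargs)


if __name__ == "__main__":
    wrapped = CountingWrapper(lambda **kw: [{"loss": 0.0}], warmup_steps=2)
    for _ in range(3):
        wrapped(forward_only=False)
    wrapped(forward_only=True)
    print(wrapped.phases)
    # ['training:eager', 'training:eager', 'training:graph', 'validation:eager']
```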
Usage examples
In the training loop: 11
Evaluation uses the same interface: 12
In the QuickStart example: 13
Design advantages
Notes
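This design follows the strategy pattern. The different scheduling algorithms are packaged as interchangeable functions, and the factory function get_forward_backward_func() picks the right one based on the configuration. This lets Megatron-LM support training at scales from a single GPU to thousands of GPUs while user code stays simple and uniform. The unit tests also check that get_forward_backward_func() returns the expected schedule for each pipeline configuration. 14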
Wiki pages you might want to explore:
- Architecture Overview (NVIDIA/Megatron-LM)
Citations
File: megatron/core/pipeline_parallel/schedules.py (L40-132)
def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state.

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and return the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)

        forward_backward_func(forward_step_func=forward_step, ...)

    data_iterator (required): an iterator over the data, will be
        passed as is to forward_step_func. Expected to be a list of
        iterators in the case of interleaved pipeline parallelism.

    model (required): the actual model. Expected to be a list of modules in the case of interleaved
        pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

    num_microbatches (int, required):
        The number of microbatches to go through

    seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
        transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
        in the config is True. Otherwise, each microbatch in the current global batch size must use
        this sequence length.

    micro_batch_size (int, required): The number of sequences in a microbatch.

    decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
        transformer. This is ignored for a single-stack transformer.

    forward_only (optional, default = False): Perform only the forward step

    collect_non_loss_data (optional, bool, default=False): TODO

    first_val_step (bool, optional): Is the first step of the validation phase. Used by
        Transformer Engine modules to only update their fp8 weights only on the first validation
        step.

    adjust_tensor_shapes_fn (Callable, optional): A function that adjusts the receive and send
        tensor shapes. Only applicable in forward_backward_pipelining_without_interleaving for now.
        Takes in a list of receive shapes and a list of send shapes and returns the adjusted
        respective list of shapes. Thus it is not used in the other forward-backward functions
        which have different shape handling.

    """
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    if pipeline_model_parallel_size > 1:
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
File: megatron/core/pipeline_parallel/schedules.py (L497-512)
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: Optional[int] = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: Optional[bool] = None,
    adjust_tensor_shapes_fn: Optional[Callable] = None,  # unused
    pg_collection: Optional[ProcessGroupCollection] = None,
):
    """Run forward and backward passes with no pipeline parallelism"""
File: megatron/core/pipeline_parallel/schedules.py (L1234-1348)
def backward_step_helper_preprocess(virtual_microbatch_id, model_chunk_id):
    """Preprocess for backward_step_helper"""
    # launch grad synchronization (default)
    if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(virtual_microbatch_id):
        enable_grad_sync()
        synchronized_model_chunks.add(model_chunk_id)

    # pylint: disable=E0606
    if _is_vp_last_stage(vp_stage=model_chunk_id) and is_pp_last_stage(pp_group):
        if len(output_tensor_grads[model_chunk_id]) == 0:
            output_tensor_grads[model_chunk_id].append(None)
    input_tensor = input_tensors[model_chunk_id].pop(0)
    output_tensor = output_tensors[model_chunk_id].pop(0)
    output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)

    return input_tensor, output_tensor, output_tensor_grad

def backward_step_helper_postprocess(virtual_microbatch_id):
    """Postprocess for backward_step_helper"""
    # launch grad synchronization (custom grad sync)
    # Note: Asynchronous communication tends to slow down compute.
    # To reduce idling from mismatched microbatch times, we launch
    # asynchronous communication at the same time across the
    # pipeline-parallel group.
    if config.grad_sync_func is not None:
        grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
        if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
            grad_sync_virtual_microbatch_id
        ):
            grad_sync_chunk_id = get_model_chunk_id(grad_sync_virtual_microbatch_id, forward=False)
            enable_grad_sync()
            config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
            synchronized_model_chunks.add(grad_sync_chunk_id)
    disable_grad_sync()

def backward_step_helper(virtual_microbatch_id):
    """Helper method to run backward step with model split into chunks"""
    nonlocal output_tensor_grads
    model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)

    input_tensor, output_tensor, output_tensor_grad = backward_step_helper_preprocess(
        virtual_microbatch_id, model_chunk_id
    )

    input_tensor_grad = backward_step(
        input_tensor, output_tensor, output_tensor_grad, model_type, config
    )

    backward_step_helper_postprocess(virtual_microbatch_id)

    return input_tensor_grad

def forward_backward_helper_wrapper(
    f_virtual_microbatch_id=None,
    b_virtual_microbatch_id=None,
    pre_forward=None,
    pre_backward=None,
    post_forward=None,
    post_backward=None,
    checkpoint_activations_microbatch=None,
):
    """
    wrap forward_helper, backward_helper, and combined_forward_backward_helper in a unified way
    """
    if config.overlap_moe_expert_parallel_comm and not forward_only:  # Combined 1F1B path
        return combined_1f1b_schedule_for_interleaved_pipelining(
            config,
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            forward_data_store,
            forward_step_helper_preprocess,
            forward_step_helper_postprocess,
            backward_step_helper_preprocess,
            backward_step_helper_postprocess,
            get_microbatch_id_in_model_chunk,
            get_model_chunk_id,
            partial(check_first_val_step, first_val_step, forward_only),
            is_first_microbatch_for_model_chunk,
            collect_non_loss_data,
            f_virtual_microbatch_id=f_virtual_microbatch_id,
            b_virtual_microbatch_id=b_virtual_microbatch_id,
            pre_forward=pre_forward,
            pre_backward=pre_backward,
            post_forward=post_forward,
            post_backward=post_backward,
        )
    else:  # Conventional interleaved 1F1B path
        forward_output_tensor = None
        backward_input_tensor_grad = None
        # forward pass
        if f_virtual_microbatch_id is not None:
            forward_model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
            if pre_forward is not None:
                pre_forward()
            forward_output_tensor = forward_step_helper(
                f_virtual_microbatch_id, checkpoint_activations_microbatch
            )
            if post_forward is not None:
                forward_output_tensor = post_forward(forward_output_tensor)

        # Backward pass.
        if b_virtual_microbatch_id is not None:
            backward_model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
            if pre_backward is not None:
                pre_backward()
            backward_input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
            if post_backward is not None:
                backward_input_tensor_grad = post_backward(backward_input_tensor_grad)
        return forward_output_tensor, backward_input_tensor_grad
File: megatron/training/training.py (L1257-1267)
losses_reduced = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=data_iterator,
    model=model,
    num_microbatches=get_num_microbatches(),
    seq_length=args.seq_length,
    micro_batch_size=args.micro_batch_size,
    decoder_seq_length=args.decoder_seq_length,
    forward_only=False,
    adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
)
File: megatron/training/training.py (L2103-2105)
forward_backward_func = get_forward_backward_func()
if args.enable_cuda_graph and args.cuda_graph_scope=="full_iteration":
    forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
File: megatron/training/training.py (L2268-2270)
) = train_step(forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func)
File: megatron/training/training.py (L2508-2517)
loss_dicts = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=data_iterator,
    model=model,
    num_microbatches=eval_num_microbatches,
    seq_length=args.seq_length,
    micro_batch_size=args.micro_batch_size,
    decoder_seq_length=args.decoder_seq_length,
    forward_only=True,
)
File: megatron/core/full_cuda_graph.py (L94-104)
class FullCudaGraphWrapper:
    """Wrapper class to enable FullIterationCUDAgraph."""

    curr_iteration = {'training': 0, 'validation': 0}
    cuda_graph = {'training': None, 'validation': None}
    result = {'training': None, 'validation': None}

    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1):
        self.forward_backward_func = forward_backward_func
        self.static_loader = StaticBufferLoader()
        self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
File: megatron/core/QuickStart.md (L223-237)
forward_backward_func = get_forward_backward_func()

# Running the model for 5 iterations
for _ in range(5):
    optim.zero_grad()

    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=train_iterator,
        model=gpt_model,
        num_microbatches=1,
        seq_length=64,
        micro_batch_size=8,
        decoder_seq_length=64,
        forward_only=False)
File: tests/unit_tests/pipeline_parallel/test_schedules.py (L37-66)
def test_get_forward_backward_func():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
    assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_without_interleaving
    )
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
        virtual_pipeline_model_parallel_size=2,
    )
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_with_interleaving
    )
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        virtual_pipeline_model_parallel_size=4,
    )
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_with_interleaving
    )
    Utils.destroy_model_parallel()
