
[LLM Training] How roll calls Megatron to compute the loss, and why partial is used

What does the line forward_step_func=partial(self.inner_forward_step, forward_func) mean?

Let's take a close look at how callbacks work in Python, and how the callback actually gets "triggered" in this specific scenario inside the roll framework.

The Basics of Python Callbacks

In Python, functions are first-class citizens. This means a function can:

  1. Be assigned to a variable.
  2. Be passed as an argument to another function.
  3. Be returned from another function.

A "callback" exploits exactly the second property.

The core idea: you define a function A and pass another function B to it as an argument. At some specific point during A's execution, A calls ("calls back") the function B that you handed it.

A simple example:

# Define a callback that specifies how to process one number
def square_callback(number):
    result = number * number
    print(f"Callback executed: The square of {number} is {result}")
    return result

# Define a main function that takes a list of data and a callback
def process_data(data_list, callback_function):
    print("Main function started, processing data...")
    results = []
    for item in data_list:
        # Here the main function invokes ("calls back") the function that was passed in
        processed_item = callback_function(item)
        results.append(processed_item)
    print("Main function finished.")
    return results

# --- Main program ---
my_data = [1, 2, 3, 4]

# Call the main function, passing square_callback as an argument
final_results = process_data(my_data, square_callback)
print(f"Final results: {final_results}")

Output:

Main function started, processing data...
Callback executed: The square of 1 is 1
Callback executed: The square of 2 is 4
Callback executed: The square of 3 is 9
Callback executed: The square of 4 is 16
Main function finished.
Final results: [1, 4, 9, 16]

In this example:

  • process_data is the caller, analogous to strategy.forward_step.
  • square_callback is the callback, analogous to forward_func_log_probs.
  • process_data is only responsible for the generic flow of iterating over the data; it has no idea what computation should be applied to each item.
  • square_callback defines the concrete computation (squaring).
  • The callback fires when process_data reaches the line processed_item = callback_function(item). At that moment, callback_function is simply a reference to square_callback. The same mechanism also works for bound methods, as sketched right after this list.

How the Callback Is Triggered in the roll Framework

Now let's apply this principle to roll's code.

The participants:

  1. Caller: strategy.forward_step
  2. Callback: self.forward_func_log_probs (a method of ActorWorker)
  3. Trigger points: the statement return output_tensor, partial(loss_func, data) inside inner_forward_step, together with the internal implementation of forward_backward_func

Let's trace the call path and see how the callback gets triggered.

Path 1: compute_log_probs -> forward_step

# ActorWorker.py
def compute_log_probs(self, data: DataProto):
    # ...
    # Here, self.forward_func_log_probs is passed to forward_step
    # as a value (a callable object)
    results = self.strategy.forward_step(
        batch=data,
        forward_func=self.forward_func_log_probs  # <--- pass in the callback
    )
    # ...

In the definition of forward_step, this callback is received and given the name forward_func:

# MegatronInferStrategy.py (or a similar Strategy class)
def forward_step(self, batch: DataProto, forward_func: Callable):  # <--- receive the callback
    # ...
    # It passes forward_func further down
    losses_reduced = self.forward_backward_func(
        forward_step_func=partial(self.inner_forward_step, forward_func),  # <--- pass it on again
        # ...
    )
    # ...

Path 2: forward_step -> forward_backward_func -> inner_forward_step

forward_backward_func is a Megatron-LM function that encapsulates the complex logic of pipeline parallelism and micro-batching. Its core job is to repeatedly call the forward_step_func you hand it.

In our case, forward_step_func is partial(self.inner_forward_step, forward_func). That means forward_backward_func will run code roughly like this inside its loop:

# Pseudocode for forward_backward_func
def forward_backward_func(forward_step_func, data_iterator, ...):
    all_outputs = []
    # Loop over every micro-batch
    for i in range(num_microbatches):
        # *** Trigger point 1 ***
        # Call the function we built with partial.
        # This executes inner_forward_step(data_iterator, model)
        output_tensor, process_fn = forward_step_func(data_iterator, model)

        # *** Trigger point 2 ***
        # process_fn is partial(loss_func, data),
        # i.e. partial(forward_func_log_probs, data).
        # This is where the callback forward_func_log_probs actually runs!
        loss, metrics = process_fn(output_tensor)
        all_outputs.append(metrics)
    return all_outputs

Now let's look at what inner_forward_step does:

# MegatronInferStrategy.py
def inner_forward_step(self, loss_func, data_iterator, model):  # loss_func is forward_func_log_probs
    # ... (prepare the data)

    # 1. Run the model forward pass
    output_tensor = model(input_ids=..., attention_mask=...)

    # 2. Prepare the callback.
    # It does not call loss_func directly; instead it returns a partial object.
    # partial(loss_func, data) means: "create a new callable that is equivalent to
    # calling loss_func(data, ...), with the first argument data already filled in"
    return output_tensor, partial(loss_func, data)

To summarize the trigger sequence:

  1. compute_log_probsforward_func_log_probs 这个方法对象作为参数传给了 forward_step
  2. forward_step 将这个方法对象进一步传给了底层的 forward_backward_func
  3. forward_backward_func 在其微批次循环中,调用了 inner_forward_step
  4. inner_forward_step 执行了模型的前向传播,得到 output_tensor
  5. inner_forward_step 并不直接调用 forward_func_log_probs。相反,它创建并返回了一个 partial 对象,这个对象“包裹”了 forward_func_log_probs 和当前的微批次数据 data
  6. forward_backward_func 拿到了 output_tensor 和这个 partial 对象(我们叫它 process_fn)。
  7. 真正的触发点forward_backward_func 调用 process_fn(output_tensor)。由于 process_fn 是一个包裹了 forward_func_log_probspartial 对象,这一步就等价于执行:
    forward_func_log_probs(data, output_tensor)
    此时,ActorWorker 中定义的回调函数就被成功触发了,它拿到了模型输出 output_tensor 和对应的输入数据 data,开始执行它自己的特定计算(即调用 op_compute_log_probs)。

This slightly roundabout design decouples Megatron-LM's own complex scheduling loop from roll's higher-level business logic: forward_backward_func only needs to know that it will be given a function returning a (tensor, callable) pair; it never has to care what that callable actually is.
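To make the whole chain concrete, here is a self-contained mock of the pattern. The names mirror roll and Megatron, but every body is a toy stand-in rather than the real implementation:

from functools import partial

# --- Megatron side (mock): only knows the (output, callable) contract ---
def forward_backward_func(forward_step_func, data_iterator, model, num_microbatches):
    all_metrics = []
    for _ in range(num_microbatches):
        # Trigger point 1: runs inner_forward_step(forward_func_log_probs, data_iterator, model)
        output_tensor, process_fn = forward_step_func(data_iterator, model)
        # Trigger point 2: runs forward_func_log_probs(data, output_tensor)
        _, metrics = process_fn(output_tensor)
        all_metrics.append(metrics)
    return all_metrics

# --- Strategy side (mock): adapts roll's callback to Megatron's contract ---
def inner_forward_step(loss_func, data_iterator, model):
    data = next(data_iterator)
    output_tensor = model(data)                     # "forward pass"
    return output_tensor, partial(loss_func, data)  # bind data, defer the call

# --- Worker side (mock): the business-specific callback ---
def forward_func_log_probs(data, output_tensor):
    metrics = {"data": data, "output": output_tensor}
    return output_tensor, metrics

def toy_model(x):
    return x * 10

results = forward_backward_func(
    forward_step_func=partial(inner_forward_step, forward_func_log_probs),
    data_iterator=iter([1, 2, 3]),
    model=toy_model,
    num_microbatches=3,
)
print(results)
# [{'data': 1, 'output': 10}, {'data': 2, 'output': 20}, {'data': 3, 'output': 30}]

Running this prints one metrics dict per micro-batch, which is exactly the role forward_data_store plays in the real scheduler.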

The reason for the partial is to fit Megatron's forward_step_func contract: Megatron only ever calls forward_step_func with (data_iterator, model). But inner_forward_step(self, loss_func, data_iterator, model) takes three parameters besides self, so the extra loss_func argument has to be bound in advance:
forward_step_func = partial(self.inner_forward_step, forward_func)
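The arity mismatch is the whole point. A minimal illustration of how partial pre-binds the extra argument (toy names, not roll code):

from functools import partial

def my_loss_func(data, output_tensor):
    return output_tensor, {}

def inner_forward_step(loss_func, data_iterator, model):
    # Needs three arguments...
    return f"loss_func={loss_func.__name__}, data_iterator={data_iterator}, model={model}"

# ...but Megatron will only ever call forward_step_func(data_iterator, model),
# so loss_func is bound up front:
adapted = partial(inner_forward_step, my_loss_func)
print(adapted("fake_iterator", "fake_model"))
# loss_func=my_loss_func, data_iterator=fake_iterator, model=fake_model

Megatron documents this exact contract in the docstring of get_forward_backward_func, quoted below: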

def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state.

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and return the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)

        forward_backward_func(forward_step_func=forward_step, ...)

You are asking how forward_backward_func uses forward_step_func, and what the partial(loss_func, loss_mask, model=model) that forward_step_func returns is used for.

How forward_step_func Is Used

forward_step_func is the user-defined forward-pass function, and it is invoked inside forward_backward_func. [1]

Where it is called

forward_step() 函数中,forward_step_func 被调用: 2

The call looks like this:

output_tensor, loss_func = forward_step_func(data_iterator, model)

Here forward_step_func receives two arguments:

  1. data_iterator: the data iterator
  2. model: the model instance

It returns two values (a minimal mock of this contract is sketched right after this list):

  1. output_tensor: the model's output (typically a loss tensor)
  2. loss_func: a partial function used for the subsequent loss processing
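To see the contract in isolation, here is a small torch-based sketch (toy model and toy data, not the real schedules.py loop) that calls a forward_step_func the same way forward_step() does and then invokes the returned loss_func:

import torch
from functools import partial

def toy_loss_func(loss_mask, output_tensor):
    # The "reduction" callable: it only receives output_tensor at call time,
    # because loss_mask was bound earlier with partial.
    losses = output_tensor.float().view(-1)
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask) / loss_mask.sum()
    return loss, {"lm loss": loss.detach()}

def toy_forward_step_func(data_iterator, model):
    tokens, loss_mask = next(data_iterator)
    output_tensor = model(tokens)
    # Contract: (forward output, callable that reduces it)
    return output_tensor, partial(toy_loss_func, loss_mask)

model = torch.nn.Linear(4, 4)
data_iterator = iter([(torch.randn(2, 4), torch.ones(2, 4))])

# This mirrors what forward_step() in schedules.py does with the pair:
output_tensor, loss_func = toy_forward_step_func(data_iterator, model)
loss, metrics = loss_func(output_tensor)
print(loss, metrics)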

What partial(loss_func, ...) Is For

The returned partial(loss_func, loss_mask, model=model) is consumed in the forward_step_calc_loss() function. [3]

Usage flow

  1. Receive loss_func. [4]
  2. Call loss_func: inside forward_step_calc_loss(), if this is the last pipeline stage, the loss_func is invoked. [5]
  3. Handle the return values: loss_func is called with output_tensor and returns the processed loss plus metadata.

A concrete example

pretrain_gpt.py 中可以看到完整的使用示例: 6

Here:

  1. forward_step_func calls the model to get output_tensor.
  2. It returns partial(loss_func, loss_mask, model=model), binding loss_mask and model up front.
  3. Later, in forward_step_calc_loss(), this partial only needs to be called with output_tensor to complete the loss computation.

Why use partial?

Using partial brings several benefits (a small equivalence check follows the list):

  • Deferred execution: forward_step_func does not compute the final loss immediately; it returns a function instead.
  • Argument binding: context such as loss_mask and model is bound ahead of time, so only output_tensor needs to be passed later.
  • Flexibility: different pipeline stages can apply different loss-processing logic.
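As a quick sanity check on the argument-binding point: calling the partial object with only output_tensor is exactly equivalent to calling loss_func with all three arguments (toy string values here, not the real loss math from pretrain_gpt.py):

from functools import partial

def loss_func(loss_mask, output_tensor, model=None):
    # Positional args bound by partial come first, then the later call's args.
    return f"mask={loss_mask}, out={output_tensor}, model={model}"

bound = partial(loss_func, "MASK", model="GPT")   # bind loss_mask and model now
assert bound("OUT") == loss_func("MASK", "OUT", model="GPT")
print(bound("OUT"))  # mask=MASK, out=OUT, model=GPT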

Notes

This design pattern is consistent across all of the training scripts, including pretrain_mamba.py [7] and pretrain_t5.py [8]. The uniform interface is what allows Megatron-LM to support different model types and training tasks.


Citations

File: megatron/core/pipeline_parallel/schedules.py (L199-287)

def forward_step_calc_loss(
    model,
    output_tensor,
    loss_func,
    config,
    vp_stage,
    collect_non_loss_data,
    num_microbatches,
    forward_data_store,
    cp_group_size=None,
    is_last_stage=None,
):
    """Calculate the loss and number of tokens for forward_step()"""
    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    model_vp_stage = getattr(model, "vp_stage", None)
    if vp_stage is not None and model_vp_stage is not None:
        assert (
            vp_stage == model_vp_stage
        ), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
    if cp_group_size is None and is_last_stage is None:
        # fallback to parallel state
        cp_group_size = parallel_state.get_context_parallel_world_size()
        is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)
    else:
        assert (
            cp_group_size is not None and is_last_stage is not None
        ), "cp_group_size and is_last_stage must be provided"
    num_tokens = torch.tensor(0, dtype=torch.int)
    if is_last_stage:
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    # Protect against division by zero when all tokens are masked
                    #   in a microbatch.
                    output_tensor /= torch.clamp(num_tokens, min=1)
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor *= cp_group_size
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # Set the loss scale for Multi-Token Prediction (MTP) loss.
    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
            MTPLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    return output_tensor, num_tokens

File: megatron/core/pipeline_parallel/schedules.py (L290-422)

def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    cp_group_size,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    vp_stage=None,
    is_last_stage=True,
):
    """Forward step for passed-in model.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used.

    Args:
        forward_step_func (callable):
            The forward step function for the model that takes the
            data iterator as the first argument, and model as the second.
            This user's forward step is expected to output a tuple of two elements:

                1. The output object from the forward step. This output object needs to be a
                    tensor or some kind of collection of tensors. The only hard requirement
                    for this object is that it needs to be acceptible as input into the second
                    function.
                2. A function to reduce (optionally) the output from the forward step. This
                    could be a reduction over the loss from the model, it could be a function that
                    grabs the output from the model and reformats, it could be a function that just
                    passes through the model output. This function must have one of the following
                    patterns, and depending on the pattern different things happen internally:

                        a. A tuple of reduced loss and some other data. Note that in this case
                            the first argument is divided by the number of global microbatches,
                            assuming it is a loss, so that the loss is stable as a function of
                            the number of devices the step is split across.
                        b. A triple of reduced loss, number of tokens, and some other data. This
                            is similar to case (a), but the loss is further averaged across the
                            number of tokens in the batch. If the user is not already averaging
                            across the number of tokens, this pattern is useful to use.
                        c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
                            of tensors, etc in the case of inference). To trigger case 3 you need
                            to specify `collect_non_loss_data=True` and you may also want to
                            specify `forward_only=True` in the call to the parent forward_backward
                            function.
        data_iterator (iterator):
            The data iterator.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for the forward step.
        forward_data_store (list):
            The list to store the forward data. If you go down path 2.a or
            2.b for the return of your forward reduction function then this will store only the
            final dimension of the output, for example the metadata output by the loss function.
            If you go down the path of 2.c then this will store the entire output of the forward
            reduction function applied to the model output.
        config (object):
            The configuration object.
        collect_non_loss_data (bool, optional):
            Whether to collect non-loss data. Defaults to False.
            This is the path to use if you want to collect arbitrary output from the model forward,
            such as with inference use cases. Defaults to False.
        checkpoint_activations_microbatch (int, optional):
            The microbatch to checkpoint activations.
            Defaults to None.
        is_first_microbatch (bool, optional):
            Whether it is the first microbatch. Defaults to False.
        current_microbatch (int, optional):
            The current microbatch. Defaults to None.
        vp_stage (int, optional):
            The virtual pipeline stage. Defaults to None.
        is_last_stage (bool, optional):
            Whether it is the last stage. Defaults to True.
            Also considering virtual stages.
            In case of PP/VPP, is_last_stage/is_vp_last_stage.

    Returns:
        Tensor or list[Tensor]: The output object(s) from the forward step.
        Tensor: The number of tokens.
    """
    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    output_tensor, num_tokens = forward_step_calc_loss(
        model,
        output_tensor,
        loss_func,
        config,
        vp_stage,
        collect_non_loss_data,
        num_microbatches,
        forward_data_store,
        cp_group_size,
        is_last_stage,
    )

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens

File: pretrain_gpt.py (L121-157)

def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            if return_schedule_plan:
                assert args.overlap_moe_expert_parallel_comm, \
                    "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
                schedule_plan = model.build_schedule_plan(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )
                return schedule_plan, partial(loss_func, loss_mask, model=model)
            else:
                output_tensor = model(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)

File: pretrain_mamba.py (L131-154)

def forward_step(data_iterator, model: MambaModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (MambaModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)

File: pretrain_t5.py (L174-198)

def forward_step(data_iterator, model: T5Model):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (T5Model): The T5 Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator', log_level=2).start()
    use_local = args.transformer_impl == "local"
    tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = get_batch(
        data_iterator, use_local
    )
    timers('batch generator').stop()

    # Forward model lm_labels
    output_tensor = model(
        tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, lm_labels=lm_labels
    )

    return output_tensor, partial(loss_func, loss_mask)