当前位置: 首页 > news >正文

【大模型训练】 roll 权重更新 过程

您在询问为什么日志中显示的是 model_update_end_onload 而不是 model_update_end_offload

日志含义解释

这个日志来自 state_offload_manger 上下文管理器,它记录了 model_update 过程中的内存状态。 1

日志时间点

state_offload_manger 在三个关键时间点记录内存:

  1. start_offload: 开始时的状态(模型在 CPU) 2
  2. start_onload: 加载到 GPU 后的状态 3
  3. end_onload: 执行完成时的状态(模型仍在 GPU) 4
  4. end_offload: 卸载回 CPU 后的状态 5

Model Update 的特殊性

Worker.start_model_update() 中,调用 state_offload_manger 时使用了特殊的 load_kwargs: 6

关键配置load_kwargs={"include": [OffloadStateType.model_params]}

这意味着:

  1. 只加载 model_params(模型权重)到 GPU
  2. 执行 model_update(收集权重并广播) 7
  3. 执行完成后,模型权重仍在 GPU 上(因为 state_offload_manger 默认在 yield 后才 offload)

为什么是 end_onload

日志显示 model_update_end_onload 是因为:

  • 这是在 yield 语句之后、offload_states() 调用之前记录的 8
  • 此时模型权重仍然在 GPU 上(onload 状态)
  • 内存占用 15.7GB 正是模型权重的大小

完整流程

PipelineWorker.start_model_update()state_offload_mangerMegatronStrategy调用 start_model_update()进入上下文管理器记录 start_offload (CPU)load_states(model_params)记录 start_onload (GPU, 15.7GB)执行 model_update()收集权重并广播返回记录 end_onload (GPU, 15.7GB) ← 您看到的日志offload_states()记录 end_offload (CPU)退出上下文管理器PipelineWorker.start_model_update()state_offload_mangerMegatronStrategy

后续的 Offload

虽然日志显示 end_onload,但 state_offload_manger 会在之后立即执行 offload: 9

只是这个 end_offload 的日志可能在您提供的日志片段中没有显示出来。

Notes

  • model_update_end_onload 表示 model_update 执行完成时,模型权重仍在 GPU 上
  • 这是正常的,因为 state_offload_manger 先记录 end_onload,然后才执行 offload
  • 15.7GB 的内存占用证明模型权重已成功加载到 GPU
  • 之后会有 end_offload 日志(内存应该降到接近 0),只是您的日志片段中没有包含

Wiki pages you might want to explore:

  • State Offloading and Memory Management (alibaba/ROLL)

Wiki pages you might want to explore:

  • Pipeline Configuration (alibaba/ROLL)
  • State Offloading and Memory Management (alibaba/ROLL)

Citations

File: roll/utils/context_managers.py (L140-186)

@contextmanager
def state_offload_manger(strategy, metrics: Dict, metric_infix: str, is_offload_states=True, load_kwargs={}):"""strategy.load_states()strategy.offload_states()为metrics埋点"""os.environ["roll_EXEC_FUNC_NAME"] = metric_infixwith Timer(name=f"{metric_infix}_total") as timer, local_profiler():with Timer(name=f"{metric_infix}_onload") as onload_timer, profiling.profile("load_states"):for device_id in range(current_platform.device_count()):current_platform.reset_max_memory_allocated(device_id)current_platform.reset_max_memory_cached(device_id)current_platform.reset_peak_memory_stats(device_id)metrics.update(_get_gpu_memory_metrics(metric_infix, "start/offload"))log_gpu_memory_usage(head=f"{metric_infix}_start_offload", logger=logger, rank=None)strategy.load_states(**load_kwargs)if load_kwargs.get("include", None) is not None:strategy.offload_states(**get_load_exclude_kwargs(load_kwargs))log_gpu_memory_usage(head=f"{metric_infix}_start_onload", logger=logger, rank=None)metrics.update(_get_gpu_memory_metrics(metric_infix, "start/onload"))metrics.update(_get_cpu_memory_metrics(metric_infix, "start"))with Timer(name=f"{metric_infix}_execute") as execute_timer, profiling.profile("execute"):yieldwith Timer(name=f"{metric_infix}_offload") as offload_timer, profiling.profile("offload_states"):metrics.update(_get_gpu_memory_metrics(metric_infix, "end/onload", with_max_frac=True))log_gpu_memory_usage(head=f"{metric_infix}_end_onload", logger=logger, rank=None)if is_offload_states:current_platform.clear_cublas_workspaces()strategy.offload_states()log_gpu_memory_usage(head=f"{metric_infix}_end_offload", logger=logger, rank=None)metrics.update(_get_gpu_memory_metrics(metric_infix, "end/offload"))metrics.update(_get_cpu_memory_metrics(metric_infix, "end"))metrics[f"time/{metric_infix}/total"] = timer.lastif is_roll_debug_mode():metrics[f"time/{metric_infix}/execute"] = execute_timer.lastmetrics[f"time/{metric_infix}/onload"] = onload_timer.lastmetrics[f"time/{metric_infix}/offload"] = offload_timer.lastdel os.environ["roll_EXEC_FUNC_NAME"]

File: roll/distributed/executor/worker.py (L176-190)

    def start_model_update(self, *args, **kwargs):metrics = {}if getattr(self, "strategy", None) is not None:with state_offload_manger(strategy=self.strategy,metrics=metrics,metric_infix=f"{self.cluster_name}/model_update",load_kwargs={"include": [OffloadStateType.model_params]},):exec_metrics: Dict = self.strategy.model_update(*args, **kwargs)metric_prefix = f"time/{self.cluster_name}/model_update"metrics.update({f"{metric_prefix}/{k}": v for k, v in exec_metrics.items()})else:self.logger.warning("worker has not strategy")

File: roll/distributed/strategy/megatron_strategy.py (L494-534)

    def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):comm_plan = self.model_update_comm_plan[model_update_name][self.worker.rank_info.pp_rank]broadcast_time_cost = 0with Timer("model_update_total") as timer_total:for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(models=self.models_unwrapped, bucket_size=256 * 1024 * 1024):ray.get(self.barrier.wait.remote())refs = []with Timer("broadcast") as timer_broadcast:for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter_in_bucket.remote(model_update_name=model_update_name,meta_infos=meta_infos, buffer=buffer, ranks_in_worker=[p2p_tgt_device["device"]["rank"]])refs.append(ref)if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0):for worker in tgt_workers:ref = worker.broadcast_bucket.remote(model_update_name=model_update_name,src_pp_rank=self.worker.rank_info.pp_rank,meta_infos=meta_infos,bucket_size=buffer.numel() * buffer.element_size(),)refs.append(ref)if len(broadcast_tgt_devices) > 0:collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])ray.get(refs)ray.get(self.barrier.wait.remote())broadcast_time_cost += timer_broadcast.lastmetrics = {"all_gather": timer_total.last - broadcast_time_cost,"broadcast": broadcast_time_cost,}return metrics

您在询问权重更新过程中 actor_infer 接收端的状态变化。

Actor Infer 的状态变化流程

根据您的日志和代码,actor_infer 使用 SGLang 策略,在权重更新过程中经历以下状态:

1. 初始状态(Offload)

在生成开始前,actor_infer 处于 offload 状态: 1

actor_infer/generate_start_offload, memory allocated (GB): 0.0

此时模型权重不在 GPU 上(使用 load_format="dummy" 初始化)。

2. 接收权重(Model Update)

actor_train 执行 model_update 时,会通过以下路径将权重发送到 actor_infer

发送端(actor_train): 2

  1. 收集权重到 bucket(all_gather_weights_as_hf_bucket
  2. 通过 P2P 或 broadcast 发送到目标 worker

接收端(actor_infer with SGLang): 3

SGLang 的 broadcast_bucket 方法会:

  1. 创建空 buffer 接收数据 4
  2. 通过 collective.broadcast 接收权重 5
  3. 调用 update_parameter_in_bucket 更新模型参数 6

3. 加载权重到 GPU(Load States)

权重接收后,在生成开始时调用 load_states(): 7

这会触发 SGLang 的 resume_memory_occupation(),将模型加载到 GPU。

4. 生成状态(Onload)

模型加载完成后,actor_infer 进入 onload 状态:

actor_infer/generate_start_onload, memory allocated (GB): 应该显示模型大小

但您的日志显示仍为 0.0GB,说明权重没有成功加载!

问题诊断

为什么 actor_infer 内存仍为 0?

从您的日志 Broadcast time: 0.00s 可以看出,broadcast 没有执行

可能的原因:

  1. 通信组未正确设置: 8

    检查日志中是否有 "warmup setup_collective_group" 消息。如果没有,说明 SGLang 没有加入通信组。

  2. Broadcast 条件不满足: 9

    只有 tp_rank == 0 && cp_rank == 0 && dp_rank == 0 的 rank 才会触发 broadcast。

  3. broadcast_tgt_devices 为空: 10

    如果配置中没有指定 broadcast 目标设备,就不会执行广播。

正常的状态变化应该是

actor_trainactor_infer (SGLang)初始状态: offload (0 GB)model_update 开始all_gather_weights_as_hf_bucketbroadcast_bucket (权重数据)collective.broadcast 接收update_parameter_in_bucketmodel.load_weights权重已更新,但仍在 CPU生成开始load_states()resume_memory_occupation()onload 状态 (~15 GB)开始生成actor_trainactor_infer (SGLang)

解决建议

  1. 检查初始化日志

    • 搜索 "setup_collective_group" 确认通信组是否建立
    • 搜索 "warmup setup_collective_group" 确认 SGLang 是否加入
  2. 验证通信计划配置
    确认 pipeline 配置中 model_update_comm_plan 包含了 actor_infer 的设备信息

  3. 临时解决方案
    使用 load_format: autoactor_infer 直接加载模型,绕过 model_update:

    actor_infer:strategy_args:strategy_config:load_format: auto
    

Notes

  • SGLang 的权重更新依赖于正确的通信组设置 11
  • load_states() 只负责将已接收的权重加载到 GPU,不负责接收权重本身 7
  • 如果 Broadcast time: 0.00s,说明权重根本没有发送,actor_infer 自然无法接收
  • vLLM 和 SGLang 的权重更新机制类似,都通过 WorkerHelper 实现 12

Wiki pages you might want to explore:

  • Strategy Pattern and Backend Abstraction (alibaba/ROLL)
  • State Offloading and Memory Management (alibaba/ROLL)

Wiki pages you might want to explore:

  • Strategy Pattern and Backend Abstraction (alibaba/ROLL)
  • State Offloading and Memory Management (alibaba/ROLL)

Citations

File: roll/pipeline/rlvr/rlvr_pipeline.py (L370-370)

                    self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info))

File: roll/distributed/strategy/megatron_strategy.py (L498-528)

            for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(models=self.models_unwrapped, bucket_size=256 * 1024 * 1024):ray.get(self.barrier.wait.remote())refs = []with Timer("broadcast") as timer_broadcast:for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter_in_bucket.remote(model_update_name=model_update_name,meta_infos=meta_infos, buffer=buffer, ranks_in_worker=[p2p_tgt_device["device"]["rank"]])refs.append(ref)if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0):for worker in tgt_workers:ref = worker.broadcast_bucket.remote(model_update_name=model_update_name,src_pp_rank=self.worker.rank_info.pp_rank,meta_infos=meta_infos,bucket_size=buffer.numel() * buffer.element_size(),)refs.append(ref)if len(broadcast_tgt_devices) > 0:collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])ray.get(refs)ray.get(self.barrier.wait.remote())broadcast_time_cost += timer_broadcast.last

File: roll/third_party/sglang/v052_patch/model_runner.py (L138-162)

    def setup_collective_group(self, comm_plan, backend, rank_in_cluster):self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {})rank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=rank_in_cluster,rank_in_worker=dist.get_rank())if rank is None:logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}")return True, "Succeeded to setup_collective_group."group_name = comm_plan_args["group_name"]master_addr = comm_plan_args["master_addr"]master_port = comm_plan_args["master_port"]world_size = len(comm_plan_args["tgt_devices"]) + 1src_pp_rank = comm_plan_args["src_pp_rank"]collective.init_collective_group(world_size, rank, backend=backend, group_name=group_name,master_addr=master_addr, master_port=master_port)# A small all_reduce for warmup.collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)self.model_update_comm_plan[src_pp_rank] = dict(rank=rank,world_size=world_size,src_pp_rank=src_pp_rank,group_name=group_name,comm_plan=comm_plan,comm_plan_args=comm_plan_args)logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")return True, "Succeeded to setup_collective_group."

File: roll/third_party/sglang/v052_patch/model_runner.py (L164-172)

    def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):if src_pp_rank not in self.model_update_comm_plan:return True, "Succeeded to broadcast_bucket."comm_plan = self.model_update_comm_plan[src_pp_rank]buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])self.update_parameter_in_bucket(meta_infos, buffer, [dist.get_rank()])return True, "Succeeded to broadcast_bucket."

File: roll/distributed/strategy/sglang_strategy.py (L262-266)

    def load_states(self, *args, **kwargs):if not self.model.is_model_in_gpu:self.model.resume_memory_occupation()logger.info("self.model.resume_memory_occupation exec ....")self.model.is_model_in_gpu = True

File: roll/third_party/vllm/worker_helper.py (L89-95)

    def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):if src_pp_rank not in self.model_update_comm_plan:returncomm_plan = self.model_update_comm_plan[src_pp_rank]buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])WorkerHelper.update_parameter_in_bucket(self, meta_infos, buffer, [dist.get_rank()])

您在询问日志中 train_step_start_onload 的内存指标含义。

日志含义解释

这个日志来自 state_offload_manger 上下文管理器,记录了训练步骤开始时模型加载到 GPU 后的内存状态。 1

三个内存指标 2

  1. memory allocated (GB): 54.85

    • 当前实际分配给 PyTorch tensors 的 GPU 内存
    • 这是真正被模型参数、梯度、激活值占用的内存
    • 通过 torch.cuda.memory_allocated() 获取
  2. memory reserved (GB): 54.98

    • PyTorch 从 GPU 预留的总内存(包括缓存)
    • 这个值 ≥ memory allocated,因为包含了 PyTorch 的内存池缓存
    • 通过 torch.cuda.memory_reserved() 获取
  3. memory max reserved (GB): [未显示完整]

    • 自上次重置以来预留内存的峰值
    • 通过 torch.cuda.max_memory_reserved() 获取

在训练流程中的位置

这个日志出现在 ActorWorker.train_step() 方法中: 3

完整的内存记录时间点

state_offload_manger 在训练步骤中记录 4 个时间点: 4

  1. start_offload: 开始时(模型在 CPU) 5
  2. start_onload: 加载到 GPU 后 ← 您看到的日志 6
  3. end_onload: 训练完成时(模型仍在 GPU) 7
  4. end_offload: 卸载回 CPU 后 8

54.85 GB 包含什么

在训练步骤的 start_onload 阶段,GPU 内存包含: 9

根据 load_kwargs,此时加载了:

  • model_params: 模型权重(约 15-20 GB,取决于模型大小)
  • other_params: 其他参数(如 LayerNorm 的统计量)
  • optimizer_states: 优化器状态(Adam 的 momentum 和 variance,约为模型参数的 2-3 倍)

为什么是 54.85 GB

假设使用 7B 模型 + DeepSpeed ZeRO-3:

  • 模型参数(FP16): ~14 GB
  • 优化器状态(Adam,FP32): ~28 GB
  • 梯度缓存: ~14 GB
  • 其他开销(激活值缓存等): ~2 GB
  • 总计: ~58 GB

您的 54.85 GB 与此估算接近,说明模型和优化器状态已成功加载到 GPU。

与 actor_train 的对比

回顾您之前的日志:

actor_train/model_update_end_onload, memory allocated (GB): 15.697

为什么差异这么大? 10

model_update 时,只加载了 model_params(~15.7 GB),不包含优化器状态。

而在 train_step 时,加载了完整的训练状态(模型 + 优化器),因此内存占用更大(~54.85 GB)。

内存重置机制

在记录 start_onload 之前,会重置内存统计: 11

这确保 memory max reserved 反映的是当前训练步骤的峰值,而不是累积值。

Notes

  • train_step_start_onload 表示训练步骤开始时,模型和优化器已加载到 GPU
  • 54.85 GB 是正常的,包含模型参数、优化器状态和梯度缓存
  • 这个值比 model_update 时的 15.7 GB 大得多,因为优化器状态占用了额外的内存
  • reserved 略大于 allocated 是正常的,因为 PyTorch 会预留一些缓存以提高性能

Wiki pages you might want to explore:

  • State Offloading and Memory Management (alibaba/ROLL)

Wiki pages you might want to explore:

  • Pipeline Configuration (alibaba/ROLL)

Citations

File: roll/utils/context_managers.py (L25-39)

def log_gpu_memory_usage(head: str, logger: logging.Logger = None, rank: int = 0):global memory_log_print_limitsif memory_log_print_limits < 0:returnmemory_log_print_limits -= 1if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):memory_allocated = current_platform.memory_allocated() / 1024**3memory_reserved = current_platform.memory_reserved() / 1024**3memory_reserved_max = current_platform.max_memory_reserved() / 1024**3rss = cpu_memory_info().rss / 1024**3message = (f"{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}, "f"memory max reserved (GB): {memory_reserved_max}, rss (GB): {rss}")logger.info(msg=message)

File: roll/utils/context_managers.py (L148-186)

    with Timer(name=f"{metric_infix}_total") as timer, local_profiler():with Timer(name=f"{metric_infix}_onload") as onload_timer, profiling.profile("load_states"):for device_id in range(current_platform.device_count()):current_platform.reset_max_memory_allocated(device_id)current_platform.reset_max_memory_cached(device_id)current_platform.reset_peak_memory_stats(device_id)metrics.update(_get_gpu_memory_metrics(metric_infix, "start/offload"))log_gpu_memory_usage(head=f"{metric_infix}_start_offload", logger=logger, rank=None)strategy.load_states(**load_kwargs)if load_kwargs.get("include", None) is not None:strategy.offload_states(**get_load_exclude_kwargs(load_kwargs))log_gpu_memory_usage(head=f"{metric_infix}_start_onload", logger=logger, rank=None)metrics.update(_get_gpu_memory_metrics(metric_infix, "start/onload"))metrics.update(_get_cpu_memory_metrics(metric_infix, "start"))with Timer(name=f"{metric_infix}_execute") as execute_timer, profiling.profile("execute"):yieldwith Timer(name=f"{metric_infix}_offload") as offload_timer, profiling.profile("offload_states"):metrics.update(_get_gpu_memory_metrics(metric_infix, "end/onload", with_max_frac=True))log_gpu_memory_usage(head=f"{metric_infix}_end_onload", logger=logger, rank=None)if is_offload_states:current_platform.clear_cublas_workspaces()strategy.offload_states()log_gpu_memory_usage(head=f"{metric_infix}_end_offload", logger=logger, rank=None)metrics.update(_get_gpu_memory_metrics(metric_infix, "end/offload"))metrics.update(_get_cpu_memory_metrics(metric_infix, "end"))metrics[f"time/{metric_infix}/total"] = timer.lastif is_roll_debug_mode():metrics[f"time/{metric_infix}/execute"] = execute_timer.lastmetrics[f"time/{metric_infix}/onload"] = onload_timer.lastmetrics[f"time/{metric_infix}/offload"] = offload_timer.lastdel os.environ["roll_EXEC_FUNC_NAME"]

File: roll/pipeline/base_worker.py (L79-85)

        with state_offload_manger(strategy=self.strategy,metrics=metrics,metric_infix=f"{self.cluster_name}/train_step",is_offload_states=is_offload_states,load_kwargs={"include": [OffloadStateType.model_params, OffloadStateType.other_params]},):

File: roll/distributed/executor/worker.py (L183-183)

                load_kwargs={"include": [OffloadStateType.model_params]},
http://www.dtcms.com/a/577539.html

相关文章:

  • QAbstractListModel 详细解析
  • 2025自动化运维厂商选型指南:数字化转型下,自动化运维平台为何成为“必选项”?
  • 如何把宏观战略转化为可执行的产品计划
  • 店铺设计素材针对网站做搜索引擎做优化
  • 温州网站排名优化公司哪家好网站推广服务合同模板
  • vscode-python学习-启动
  • STM32 串口线A-B
  • 使用 dnsmasq 搭建本地 DNS 服务器完整指南
  • 水墨画风格网站wordpress大气摄影主题
  • 详细介绍一下“集中同步+分布式入库”方案的具体实现步骤
  • 网站建设需要上传数据库吗双创网站建设
  • 轻量级Kafka集群管理工具
  • 嵌入式计算架构变革:ARM 浪潮下的替代革命与杰和科技产品布局
  • HarmonyOs鸿蒙开发,日期滑动选择器
  • 鸿蒙ArkUI布局与样式进阶(十六)——页面级变量、函数注入与 @BuilderParam 插槽机制全解析(附详细注释)
  • 网站加载页面怎么做seo网站设计外包
  • sqlserver2019中,一列为计算项目,一列为计算公式及计算项目代表的数字,如何找出一个计算项目是数字改变时,会有多个涉及的计算项目
  • 网站截图可以做证据吗微信小程序模板免费下载
  • 手机兼容测试服务提供商对比分析:腾讯优测Utest的优势与挑战
  • repo xml语法
  • 如何选择能够高效运行的云手机
  • IFC转换为3DXML的技术指南在线转换推荐
  • 站长之家工具网页界面设计的内容五大设计要素
  • MAUI劝退:内部消息机制(社区工具包)
  • 西安 网站 公司wordpress同步微信公众号
  • Xshell效率实战:SSH管理秘籍技术大纲
  • 和平精英java 游戏程序
  • 【Java】异常详解+实例演示+知识总结
  • 【大模型训练】sglang 权重绑定和roll HF Meg相互转化
  • 有那个网站可以做报名链接的网站开发项目简单描述