[LLM Training] The ROLL weight-update process
You are asking why the log shows model_update_end_onload instead of model_update_end_offload.
What the log means
This log comes from the state_offload_manger context manager, which records the memory state during the model_update process. 1
Logging time points
state_offload_manger records memory at four key time points (sketched below):
- start_offload: state at the beginning (model on CPU) 2
- start_onload: state after loading onto the GPU 3
- end_onload: state when execution finishes (model still on GPU) 4
- end_offload: state after offloading back to the CPU 5
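A minimal sketch of the logging order inside state_offload_manger (heavily simplified: timers and metrics are omitted, and the logging stub below stands in for the real log_gpu_memory_usage shown in the Citations):

```python
from contextlib import contextmanager

def log_gpu_memory_usage(head: str) -> None:
    # Stand-in for roll.utils.context_managers.log_gpu_memory_usage
    print(head)

@contextmanager
def state_offload_sketch(strategy, metric_infix, load_kwargs=None, is_offload_states=True):
    """Simplified sketch of the logging order in state_offload_manger."""
    log_gpu_memory_usage(f"{metric_infix}_start_offload")   # before load: states still offloaded
    strategy.load_states(**(load_kwargs or {}))              # bring the requested states onto the GPU
    log_gpu_memory_usage(f"{metric_infix}_start_onload")     # after load
    yield                                                     # the caller's work (e.g. model_update) runs here
    log_gpu_memory_usage(f"{metric_infix}_end_onload")       # logged BEFORE any offload happens
    if is_offload_states:
        strategy.offload_states()                             # move states back to CPU
    log_gpu_memory_usage(f"{metric_infix}_end_offload")      # appears only after the offload completes
```

This ordering is why a model_update_end_onload line always precedes the matching model_update_end_offload line.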
What is special about model_update
In Worker.start_model_update(), state_offload_manger is called with a special load_kwargs: 6
Key setting: load_kwargs={"include": [OffloadStateType.model_params]}
This means:
- Only model_params (the model weights) are loaded onto the GPU
- model_update is executed (weights are gathered and broadcast) 7
- After execution, the model weights are still on the GPU (state_offload_manger only offloads after the yield)
Why end_onload
The log shows model_update_end_onload because:
- It is recorded after the yield statement and before offload_states() is called 8
- At that point the model weights are still on the GPU (onload state)
- The 15.7 GB memory usage is exactly the size of the model weights
Complete flow
start_offload → load_states (model_params only) → start_onload → model_update executes → end_onload (the log you saw) → offload_states → end_offload
The subsequent offload
Although the log shows end_onload, state_offload_manger performs the offload immediately afterwards: 9
The corresponding end_offload log may simply not be included in the log snippet you provided.
Notes
- model_update_end_onload means that when model_update finished executing, the model weights were still on the GPU
- This is expected, because state_offload_manger records end_onload first and only then performs the offload
- The 15.7 GB memory usage confirms that the model weights were successfully loaded onto the GPU
- An end_offload log follows afterwards (memory should drop close to 0); it is just not included in your log snippet
Wiki pages you might want to explore:
- Pipeline Configuration (alibaba/ROLL)
- State Offloading and Memory Management (alibaba/ROLL)
Citations
File: roll/utils/context_managers.py (L140-186)
@contextmanager
def state_offload_manger(strategy, metrics: Dict, metric_infix: str, is_offload_states=True, load_kwargs={}):
    """
    strategy.load_states()
    strategy.offload_states()
    为metrics埋点
    """
    os.environ["roll_EXEC_FUNC_NAME"] = metric_infix
    with Timer(name=f"{metric_infix}_total") as timer, local_profiler():
        with Timer(name=f"{metric_infix}_onload") as onload_timer, profiling.profile("load_states"):
            for device_id in range(current_platform.device_count()):
                current_platform.reset_max_memory_allocated(device_id)
                current_platform.reset_max_memory_cached(device_id)
                current_platform.reset_peak_memory_stats(device_id)
            metrics.update(_get_gpu_memory_metrics(metric_infix, "start/offload"))
            log_gpu_memory_usage(head=f"{metric_infix}_start_offload", logger=logger, rank=None)
            strategy.load_states(**load_kwargs)
            if load_kwargs.get("include", None) is not None:
                strategy.offload_states(**get_load_exclude_kwargs(load_kwargs))
            log_gpu_memory_usage(head=f"{metric_infix}_start_onload", logger=logger, rank=None)
            metrics.update(_get_gpu_memory_metrics(metric_infix, "start/onload"))
            metrics.update(_get_cpu_memory_metrics(metric_infix, "start"))
        with Timer(name=f"{metric_infix}_execute") as execute_timer, profiling.profile("execute"):
            yield
        with Timer(name=f"{metric_infix}_offload") as offload_timer, profiling.profile("offload_states"):
            metrics.update(_get_gpu_memory_metrics(metric_infix, "end/onload", with_max_frac=True))
            log_gpu_memory_usage(head=f"{metric_infix}_end_onload", logger=logger, rank=None)
            if is_offload_states:
                current_platform.clear_cublas_workspaces()
                strategy.offload_states()
            log_gpu_memory_usage(head=f"{metric_infix}_end_offload", logger=logger, rank=None)
            metrics.update(_get_gpu_memory_metrics(metric_infix, "end/offload"))
            metrics.update(_get_cpu_memory_metrics(metric_infix, "end"))
    metrics[f"time/{metric_infix}/total"] = timer.last
    if is_roll_debug_mode():
        metrics[f"time/{metric_infix}/execute"] = execute_timer.last
        metrics[f"time/{metric_infix}/onload"] = onload_timer.last
        metrics[f"time/{metric_infix}/offload"] = offload_timer.last
    del os.environ["roll_EXEC_FUNC_NAME"]
File: roll/distributed/executor/worker.py (L176-190)
def start_model_update(self, *args, **kwargs):
    metrics = {}
    if getattr(self, "strategy", None) is not None:
        with state_offload_manger(
            strategy=self.strategy,
            metrics=metrics,
            metric_infix=f"{self.cluster_name}/model_update",
            load_kwargs={"include": [OffloadStateType.model_params]},
        ):
            exec_metrics: Dict = self.strategy.model_update(*args, **kwargs)
            metric_prefix = f"time/{self.cluster_name}/model_update"
            metrics.update({f"{metric_prefix}/{k}": v for k, v in exec_metrics.items()})
    else:
        self.logger.warning("worker has not strategy")
File: roll/distributed/strategy/megatron_strategy.py (L494-534)
def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):
    comm_plan = self.model_update_comm_plan[model_update_name][self.worker.rank_info.pp_rank]
    broadcast_time_cost = 0
    with Timer("model_update_total") as timer_total:
        for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(
            models=self.models_unwrapped, bucket_size=256 * 1024 * 1024
        ):
            ray.get(self.barrier.wait.remote())
            refs = []
            with Timer("broadcast") as timer_broadcast:
                for p2p_tgt_device in p2p_tgt_devices:
                    p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
                    ref = p2p_tgt_worker.update_parameter_in_bucket.remote(
                        model_update_name=model_update_name,
                        meta_infos=meta_infos,
                        buffer=buffer,
                        ranks_in_worker=[p2p_tgt_device["device"]["rank"]],
                    )
                    refs.append(ref)
                if (
                    self.worker.rank_info.tp_rank == 0
                    and self.worker.rank_info.cp_rank == 0
                    and self.worker.rank_info.dp_rank == 0
                ):
                    for worker in tgt_workers:
                        ref = worker.broadcast_bucket.remote(
                            model_update_name=model_update_name,
                            src_pp_rank=self.worker.rank_info.pp_rank,
                            meta_infos=meta_infos,
                            bucket_size=buffer.numel() * buffer.element_size(),
                        )
                        refs.append(ref)
                    if len(broadcast_tgt_devices) > 0:
                        collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
                ray.get(refs)
                ray.get(self.barrier.wait.remote())
            broadcast_time_cost += timer_broadcast.last
    metrics = {
        "all_gather": timer_total.last - broadcast_time_cost,
        "broadcast": broadcast_time_cost,
    }
    return metrics
You are asking about the state changes on the actor_infer (receiving) side during the weight update.
The actor_infer state-change flow
Based on your logs and the code, actor_infer uses the SGLang strategy and goes through the following states during a weight update:
1. Initial state (offload)
Before generation starts, actor_infer is in the offload state: 1
actor_infer/generate_start_offload, memory allocated (GB): 0.0
At this point the model weights are not on the GPU (the model was initialized with load_format="dummy").
2. Receiving weights (model update)
When actor_train runs model_update, the weights are sent to actor_infer along the following path:
Sender side (actor_train): 2
- Gather weights into buckets (all_gather_weights_as_hf_bucket)
- Send them to the target workers via P2P or broadcast
Receiver side (actor_infer with SGLang): 3
SGLang's broadcast_bucket method will:
- Create an empty buffer to receive the data 4
- Receive the weights via collective.broadcast 5
- Call update_parameter_in_bucket to update the model parameters 6
3. Loading the weights onto the GPU (load states)
After the weights have been received, load_states() is called when generation starts: 7
This triggers SGLang's resume_memory_occupation(), which loads the model onto the GPU.
4. Generation state (onload)
After the model is loaded, actor_infer enters the onload state:
actor_infer/generate_start_onload, memory allocated (GB): should show the model size
However, your log still shows 0.0 GB, which means the weights were not loaded successfully!
Problem diagnosis
Why is actor_infer's memory still 0?
Your log line Broadcast time: 0.00s shows that the broadcast never ran.
Possible causes (see the sketch after this list):
- The collective communication group was not set up correctly: 8
  Check the logs for a "warmup setup_collective_group" message. If it is missing, SGLang never joined the communication group.
- The broadcast condition was not met: 9
  Only the rank with tp_rank == 0 && cp_rank == 0 && dp_rank == 0 triggers the broadcast.
- broadcast_tgt_devices is empty: 10
  If no broadcast target devices are specified in the configuration, the broadcast is not executed.
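For illustration, here is a minimal sketch of the sender-side gating described above, mirroring the condition in the cited megatron_strategy.model_update; the helper function itself is hypothetical and not part of ROLL:

```python
def should_broadcast(rank_info, broadcast_tgt_devices) -> bool:
    """Hypothetical helper mirroring the gating in megatron_strategy.model_update:
    only the tp/cp/dp rank-0 worker broadcasts, and only if there are broadcast targets."""
    is_src_rank = (
        rank_info.tp_rank == 0
        and rank_info.cp_rank == 0
        and rank_info.dp_rank == 0
    )
    return is_src_rank and len(broadcast_tgt_devices) > 0

# If this evaluates to False on every sender rank, collective.broadcast is never called,
# the "broadcast" timer stays at 0.00s, and actor_infer receives nothing.
```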
The normal sequence of state changes should be:
generate_start_offload (≈0 GB) → weights received via model_update → load_states() / resume_memory_occupation() → generate_start_onload (≈ model size) → generation → generate_end_onload → offload → generate_end_offload (≈0 GB again)
Suggested fixes
- Check the initialization logs:
  - Search for "setup_collective_group" to confirm that the communication group was created
  - Search for "warmup setup_collective_group" to confirm that SGLang joined it
- Verify the communication-plan configuration:
  Confirm that model_update_comm_plan in the pipeline configuration includes actor_infer's device information
- Temporary workaround:
  Use load_format: auto so that actor_infer loads the model directly and bypasses model_update:
    actor_infer:
      strategy_args:
        strategy_config:
          load_format: auto
Notes
- SGLang's weight update relies on the communication group being set up correctly 11
- load_states() only loads already-received weights onto the GPU; it does not receive the weights itself 7
- If Broadcast time: 0.00s, the weights were never sent, so actor_infer has nothing to receive
- vLLM and SGLang use a similar weight-update mechanism, both implemented via WorkerHelper 12
Wiki pages you might want to explore:
- Strategy Pattern and Backend Abstraction (alibaba/ROLL)
- State Offloading and Memory Management (alibaba/ROLL)
Citations
File: roll/pipeline/rlvr/rlvr_pipeline.py (L370-370)
self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info))
File: roll/distributed/strategy/megatron_strategy.py (L498-528)
for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(
    models=self.models_unwrapped, bucket_size=256 * 1024 * 1024
):
    ray.get(self.barrier.wait.remote())
    refs = []
    with Timer("broadcast") as timer_broadcast:
        for p2p_tgt_device in p2p_tgt_devices:
            p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
            ref = p2p_tgt_worker.update_parameter_in_bucket.remote(
                model_update_name=model_update_name,
                meta_infos=meta_infos,
                buffer=buffer,
                ranks_in_worker=[p2p_tgt_device["device"]["rank"]],
            )
            refs.append(ref)
        if (
            self.worker.rank_info.tp_rank == 0
            and self.worker.rank_info.cp_rank == 0
            and self.worker.rank_info.dp_rank == 0
        ):
            for worker in tgt_workers:
                ref = worker.broadcast_bucket.remote(
                    model_update_name=model_update_name,
                    src_pp_rank=self.worker.rank_info.pp_rank,
                    meta_infos=meta_infos,
                    bucket_size=buffer.numel() * buffer.element_size(),
                )
                refs.append(ref)
            if len(broadcast_tgt_devices) > 0:
                collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
        ray.get(refs)
        ray.get(self.barrier.wait.remote())
    broadcast_time_cost += timer_broadcast.last
File: roll/third_party/sglang/v052_patch/model_runner.py (L138-162)
def setup_collective_group(self, comm_plan, backend, rank_in_cluster):
    self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {})
    rank, comm_plan_args = get_dist_info_from_comm_plan(
        comm_plan, rank_in_cluster=rank_in_cluster, rank_in_worker=dist.get_rank()
    )
    if rank is None:
        logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}")
        return True, "Succeeded to setup_collective_group."
    group_name = comm_plan_args["group_name"]
    master_addr = comm_plan_args["master_addr"]
    master_port = comm_plan_args["master_port"]
    world_size = len(comm_plan_args["tgt_devices"]) + 1
    src_pp_rank = comm_plan_args["src_pp_rank"]
    collective.init_collective_group(
        world_size, rank, backend=backend, group_name=group_name,
        master_addr=master_addr, master_port=master_port
    )
    # A small all_reduce for warmup.
    collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)
    self.model_update_comm_plan[src_pp_rank] = dict(
        rank=rank,
        world_size=world_size,
        src_pp_rank=src_pp_rank,
        group_name=group_name,
        comm_plan=comm_plan,
        comm_plan_args=comm_plan_args,
    )
    logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")
    return True, "Succeeded to setup_collective_group."
File: roll/third_party/sglang/v052_patch/model_runner.py (L164-172)
def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):
    if src_pp_rank not in self.model_update_comm_plan:
        return True, "Succeeded to broadcast_bucket."
    comm_plan = self.model_update_comm_plan[src_pp_rank]
    buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)
    collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
    self.update_parameter_in_bucket(meta_infos, buffer, [dist.get_rank()])
    return True, "Succeeded to broadcast_bucket."
File: roll/distributed/strategy/sglang_strategy.py (L262-266)
def load_states(self, *args, **kwargs):
    if not self.model.is_model_in_gpu:
        self.model.resume_memory_occupation()
        logger.info("self.model.resume_memory_occupation exec ....")
        self.model.is_model_in_gpu = True
File: roll/third_party/vllm/worker_helper.py (L89-95)
def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):
    if src_pp_rank not in self.model_update_comm_plan:
        return
    comm_plan = self.model_update_comm_plan[src_pp_rank]
    buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)
    collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
    WorkerHelper.update_parameter_in_bucket(self, meta_infos, buffer, [dist.get_rank()])
You are asking about the meaning of the memory metrics in the train_step_start_onload log line.
What the log means
This log comes from the state_offload_manger context manager and records the memory state after the model has been loaded onto the GPU at the start of a training step. 1
The three memory metrics 2
- memory allocated (GB): 54.85
  - GPU memory currently allocated to PyTorch tensors
  - This is the memory actually occupied by model parameters, gradients, and activations
  - Obtained via torch.cuda.memory_allocated()
- memory reserved (GB): 54.98
  - Total memory PyTorch has reserved on the GPU (including its cache)
  - This value is ≥ memory allocated because it includes PyTorch's memory-pool cache
  - Obtained via torch.cuda.memory_reserved()
- memory max reserved (GB): [not fully shown]
  - The peak reserved memory since the last reset
  - Obtained via torch.cuda.max_memory_reserved()
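As a small example (plain torch.cuda calls; ROLL itself reads these values through its current_platform abstraction, as shown in the cited log_gpu_memory_usage), the same three numbers can be reproduced like this:

```python
import torch

def gpu_memory_snapshot(device: int = 0) -> dict:
    """Return the three metrics from the log line, in GB (GiB)."""
    gb = 1024 ** 3
    return {
        "memory allocated (GB)": torch.cuda.memory_allocated(device) / gb,
        "memory reserved (GB)": torch.cuda.memory_reserved(device) / gb,
        "memory max reserved (GB)": torch.cuda.max_memory_reserved(device) / gb,
    }

if torch.cuda.is_available():
    print(gpu_memory_snapshot())
```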
Where this appears in the training flow
This log is emitted from the ActorWorker.train_step() method: 3
The complete set of memory-recording time points
state_offload_manger records four time points during a training step: 4
- start_offload: at the beginning (model on CPU) 5
- start_onload: after loading onto the GPU ← the log you saw 6
- end_onload: when training finishes (model still on GPU) 7
- end_offload: after offloading back to the CPU 8
What the 54.85 GB contains
At the start_onload point of a training step, GPU memory contains: 9
According to load_kwargs, the following are loaded at this point:
- model_params: model weights (roughly 15-20 GB, depending on model size)
- other_params: other parameters (e.g. LayerNorm statistics)
- optimizer_states: optimizer state (Adam momentum and variance, roughly 2-3x the size of the model parameters)
Why 54.85 GB
Assuming a 7B model with DeepSpeed ZeRO-3:
- Model parameters (FP16): ~14 GB
- Optimizer states (Adam, FP32): ~28 GB
- Gradient buffers: ~14 GB
- Other overhead (activation caches, etc.): ~2 GB
- Total: ~58 GB
Your 54.85 GB is close to this estimate, which indicates that the model and optimizer state were loaded onto the GPU successfully.
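A quick back-of-the-envelope check using the rough per-component figures above (the per-parameter byte counts are assumptions and depend on the ZeRO stage and precision settings):

```python
# Rough estimate for a 7B-parameter model, following the figures listed above.
params = 7e9
GiB = 1024 ** 3

fp16_weights = params * 2 / GiB       # ~13.0 GiB (listed above as ~14 GB)
adam_states_fp32 = params * 4 / GiB   # ~26.1 GiB (listed above as ~28 GB)
grad_buffers = params * 2 / GiB       # ~13.0 GiB (listed above as ~14 GB)
misc_overhead = 2.0                   # activations, workspaces, ...

total = fp16_weights + adam_states_fp32 + grad_buffers + misc_overhead
print(f"estimated total: {total:.1f} GiB")  # ~54.2 GiB, close to the logged 54.85
```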
Comparison with the earlier actor_train log
Recall your earlier log:
actor_train/model_update_end_onload, memory allocated (GB): 15.697
Why is the difference so large? 10
During model_update, only model_params (~15.7 GB) are loaded, without the optimizer state.
During train_step, the full training state (model + optimizer) is loaded, so the memory footprint is much larger (~54.85 GB).
Memory-statistics reset mechanism
Before start_onload is recorded, the memory statistics are reset: 11
This ensures that memory max reserved reflects the peak of the current training step rather than an accumulated value.
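A minimal illustration of that reset-then-measure pattern using plain torch.cuda calls (ROLL performs the reset through current_platform, as in the cited code):

```python
import torch

def measure_step_peak(step_fn, device: int = 0) -> float:
    """Reset the peak counters, run one step, and return that step's peak reserved memory in GiB."""
    torch.cuda.reset_peak_memory_stats(device)   # clears max_memory_allocated / max_memory_reserved
    step_fn()                                    # e.g. one train_step
    return torch.cuda.max_memory_reserved(device) / 1024 ** 3
```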
Notes
- train_step_start_onload means that at the start of the training step the model and optimizer have been loaded onto the GPU
- 54.85 GB is normal; it covers model parameters, optimizer state, and gradient buffers
- This value is much larger than the 15.7 GB seen during model_update because the optimizer state takes additional memory
- reserved being slightly larger than allocated is normal, since PyTorch keeps some memory cached for performance
Wiki pages you might want to explore:
- Pipeline Configuration (alibaba/ROLL)
- State Offloading and Memory Management (alibaba/ROLL)
Citations
File: roll/utils/context_managers.py (L25-39)
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, rank: int = 0):
    global memory_log_print_limits
    if memory_log_print_limits < 0:
        return
    memory_log_print_limits -= 1
    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
        memory_allocated = current_platform.memory_allocated() / 1024**3
        memory_reserved = current_platform.memory_reserved() / 1024**3
        memory_reserved_max = current_platform.max_memory_reserved() / 1024**3
        rss = cpu_memory_info().rss / 1024**3
        message = (
            f"{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}, "
            f"memory max reserved (GB): {memory_reserved_max}, rss (GB): {rss}"
        )
        logger.info(msg=message)
File: roll/utils/context_managers.py (L148-186)
with Timer(name=f"{metric_infix}_total") as timer, local_profiler():
    with Timer(name=f"{metric_infix}_onload") as onload_timer, profiling.profile("load_states"):
        for device_id in range(current_platform.device_count()):
            current_platform.reset_max_memory_allocated(device_id)
            current_platform.reset_max_memory_cached(device_id)
            current_platform.reset_peak_memory_stats(device_id)
        metrics.update(_get_gpu_memory_metrics(metric_infix, "start/offload"))
        log_gpu_memory_usage(head=f"{metric_infix}_start_offload", logger=logger, rank=None)
        strategy.load_states(**load_kwargs)
        if load_kwargs.get("include", None) is not None:
            strategy.offload_states(**get_load_exclude_kwargs(load_kwargs))
        log_gpu_memory_usage(head=f"{metric_infix}_start_onload", logger=logger, rank=None)
        metrics.update(_get_gpu_memory_metrics(metric_infix, "start/onload"))
        metrics.update(_get_cpu_memory_metrics(metric_infix, "start"))
    with Timer(name=f"{metric_infix}_execute") as execute_timer, profiling.profile("execute"):
        yield
    with Timer(name=f"{metric_infix}_offload") as offload_timer, profiling.profile("offload_states"):
        metrics.update(_get_gpu_memory_metrics(metric_infix, "end/onload", with_max_frac=True))
        log_gpu_memory_usage(head=f"{metric_infix}_end_onload", logger=logger, rank=None)
        if is_offload_states:
            current_platform.clear_cublas_workspaces()
            strategy.offload_states()
        log_gpu_memory_usage(head=f"{metric_infix}_end_offload", logger=logger, rank=None)
        metrics.update(_get_gpu_memory_metrics(metric_infix, "end/offload"))
        metrics.update(_get_cpu_memory_metrics(metric_infix, "end"))
metrics[f"time/{metric_infix}/total"] = timer.last
if is_roll_debug_mode():
    metrics[f"time/{metric_infix}/execute"] = execute_timer.last
    metrics[f"time/{metric_infix}/onload"] = onload_timer.last
    metrics[f"time/{metric_infix}/offload"] = offload_timer.last
del os.environ["roll_EXEC_FUNC_NAME"]
File: roll/pipeline/base_worker.py (L79-85)
with state_offload_manger(
    strategy=self.strategy,
    metrics=metrics,
    metric_infix=f"{self.cluster_name}/train_step",
    is_offload_states=is_offload_states,
    load_kwargs={"include": [OffloadStateType.model_params, OffloadStateType.other_params]},
):
File: roll/distributed/executor/worker.py (L183-183)
load_kwargs={"include": [OffloadStateType.model_params]},
