# [LLM Training] Weight Updates
You are asking how, in the ROLL framework, weights are sent from actor_train to actor_infer when both the sender (actor_train) and the receiver (actor_infer) have DP (Data Parallel) and PP (Pipeline Parallel) groups, and how the corresponding workers are located.
## Problem Overview
In a distributed training scenario:

- actor_train: trains with DP + PP, e.g. 8 workers configured as DP=2, PP=4
- actor_infer: may also run inference with DP + PP, e.g. 4 workers configured as DP=2, PP=2

The key question: how is the correct communication mapping established, so that every actor_train worker knows which actor_infer workers it should send weights to?
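To make the rank layout concrete, here is a toy illustration of how the 8 training workers could be assigned (dp_rank, pp_rank) pairs. The ordering convention is an assumption made purely for illustration; ROLL derives the actual assignment from each worker's rank_info.

```python
# Hypothetical layout for illustration only: 8 workers, DP=2, PP=4.
# Assumes pp_rank varies fastest within a DP replica; ROLL's real ordering
# comes from rank_info and may differ.
DP, PP = 2, 4
for global_rank in range(DP * PP):
    dp_rank, pp_rank = divmod(global_rank, PP)
    print(f"worker {global_rank}: dp_rank={dp_rank}, pp_rank={pp_rank}")
```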
## Building the Communication Plan

### 1. Creating the Communication Plan at Pipeline Initialization

The BasePipeline.set_model_update_pair method creates the communication plan from actor_train to actor_infer [1].

This method invokes the underlying plan-generation logic, which uses the parallel configurations of the source and target clusters to compute which ranks each rank should communicate with.
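For example, the RLVR pipeline registers the pair as follows (quoted from roll/pipeline/rlvr/rlvr_pipeline.py, see the citations below):

```python
self.set_model_update_pair(
    src_cluster=self.actor_train,
    tgt_cluster=self.actor_infer,
    frequency=self.pipeline_config.actor_train.model_update_frequency,
)
```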
### 2. Structure of the Communication Plan

The communication plan is a nested dictionary with the following structure:
```python
comm_plan = {
    src_pp_rank: {  # pipeline-parallel rank on the source (actor_train) side
        "group_name": "model_update_actor_train_0_to_actor_infer",
        "master_addr": "10.0.0.1",
        "master_port": "29500",
        "tgt_devices": [  # list of target devices
            {"rank": 0, "device": {"rank": 0, "node_rank": 0, "gpu_rank": 0}},
            {"rank": 1, "device": {"rank": 0, "node_rank": 0, "gpu_rank": 0}},
        ],
        "src_pp_rank": 0,
    }
}
```
[2]

### 3. Key Mapping Logic

The communication plan is generated according to the following principles.

**Pipeline Parallel mapping**:
- Each actor_train PP rank is responsible for sending weights to the corresponding actor_infer PP rank
- Example: actor_train PP rank 0 → actor_infer PP rank 0

**Data Parallel optimization**:

- Along the DP dimension, **usually only one DP rank (typically rank 0) broadcasts the weights**
- The other DP ranks do not take part in the transfer, because their weights are identical

**Tensor Parallel handling**:

- If actor_train and actor_infer have the same TP size, each TP rank sends its own weight shard
- If the TP sizes differ, the weights must first be all-gathered and then re-sharded

## Megatron Strategy Implementation

In the Megatron strategy, the `model_update` method implements the complete weight-sending logic [3].

### Key Steps

**Step 1: Collecting the weights**
```python
for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(
    models=self.models_unwrapped, bucket_size=256 * 1024 * 1024
):
```
- The weights are all-gathered across the DP/TP/PP dimensions and converted into HF-format buckets
- Each bucket is roughly 256 MB and holds multiple parameters (a packing sketch follows below)
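Conceptually, bucket packing looks roughly like the sketch below. This is a simplified illustration only: the meta_infos format (name/shape/dtype/offset) is an assumption, and the real packing lives in all_gather_weights_as_hf_bucket, which is not quoted in this answer.

```python
import torch

def pack_bucket(named_tensors):
    """Toy packer: flatten parameters into one int8 buffer plus metadata.

    The meta_infos layout here (name/shape/dtype/offset/numel) is hypothetical;
    ROLL's real metadata format may differ. Splitting into ~256 MB buckets is omitted.
    """
    meta_infos, chunks, offset = [], [], 0
    for name, tensor in named_tensors:
        raw = tensor.detach().contiguous().view(-1).view(torch.int8)  # reinterpret as bytes
        meta_infos.append({
            "name": name,
            "shape": tuple(tensor.shape),
            "dtype": tensor.dtype,
            "offset": offset,
            "numel": raw.numel(),
        })
        chunks.append(raw)
        offset += raw.numel()
    buffer = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.int8)
    return meta_infos, buffer
```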
**Step 2: Synchronization barrier**
```python
ray.get(self.barrier.wait.remote())
```
- Ensures all actor_train ranks are ready to send before the transfer starts
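The barrier here is a Ray actor shared by the actor_train workers. ROLL's actual barrier implementation is not quoted in this answer; a minimal, single-use sketch of the idea:

```python
import asyncio
import ray

@ray.remote
class Barrier:
    """Toy n-party, single-use barrier: wait() returns only once all parties have arrived."""

    def __init__(self, num_parties: int):
        self.num_parties = num_parties
        self.arrived = 0
        self.event = asyncio.Event()

    async def wait(self):
        self.arrived += 1
        if self.arrived == self.num_parties:
            self.event.set()  # the last arrival releases everyone
        await self.event.wait()

# Usage sketch: every sender blocks here until all senders have arrived.
# barrier = Barrier.remote(num_parties=num_senders)
# ray.get(barrier.wait.remote())
```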
**Step 3: P2P update (point-to-point)**
```python
for p2p_tgt_device in p2p_tgt_devices:
    p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
    ref = p2p_tgt_worker.update_parameter_in_bucket.remote(...)
```
- For specific target devices, the weights are sent directly via a Ray RPC
- `p2p_tgt_devices` is the set of devices that the communication plan designates for point-to-point transfer
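On the receiving end, update_parameter_in_bucket slices each parameter back out of the flat buffer using the metadata. A toy counterpart to the packing sketch above, using the same assumed meta_infos format (not ROLL's actual implementation):

```python
def unpack_bucket(meta_infos, buffer):
    """Toy unpacker matching pack_bucket above; yields (name, tensor) pairs."""
    for info in meta_infos:
        raw = buffer[info["offset"] : info["offset"] + info["numel"]]
        yield info["name"], raw.view(info["dtype"]).view(info["shape"])
```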
**Step 4: Broadcast update**
```python
if (
    self.worker.rank_info.tp_rank == 0
    and self.worker.rank_info.cp_rank == 0
    and self.worker.rank_info.dp_rank == 0
):
    for worker in tgt_workers:
        ref = worker.broadcast_bucket.remote(...)
```
- Only the worker with TP rank 0, CP rank 0, and DP rank 0 performs the broadcast
- It sends the broadcast command to all target workers
**Step 5: Collective broadcast**
```python
if len(broadcast_tgt_devices) > 0:
    collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
```
- Uses an NCCL collective broadcast to distribute the weights over the communication group
- `broadcast_tgt_devices` is the list of devices that the communication plan designates for broadcast
## Handling on the Receiving Side

On the actor_infer side, weight reception works as follows.

### 1. Receiving a Broadcast Bucket [4]

Workflow:
- Check whether `src_pp_rank` is present in `model_update_comm_plan`
- If it is not, this worker does not need to receive weights from that PP rank and returns immediately
- If it is, allocate an empty buffer and receive the data via `collective.broadcast`
- Call `update_parameter_in_bucket` to load the weights into the model (the cited implementation is shown below)
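This is exactly what the vLLM worker helper's broadcast_bucket does (quoted from roll/third_party/vllm/worker_helper.py, see the citations below):

```python
def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):
    if src_pp_rank not in self.model_update_comm_plan:
        return
    comm_plan = self.model_update_comm_plan[src_pp_rank]
    buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)
    collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
    WorkerHelper.update_parameter_in_bucket(self, meta_infos, buffer, [dist.get_rank()])
```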
### 2. Establishing the Communication Group

Before receiving any weights, actor_infer must first set up the communication group [5].

Key points:
- `get_dist_info_from_comm_plan` extracts the current worker's rank and world_size from the communication plan (a sketch of this lookup follows the list)
- If `rank is None`, the current worker does not participate in this communication group and returns immediately
- `collective.init_collective_group` initializes the NCCL communication group
- A warmup all-reduce is run to verify that the communication group works
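The body of get_dist_info_from_comm_plan is not quoted in this answer. A plausible sketch of the lookup, assuming the sender occupies rank 0 and the targets occupy ranks 1..N in tgt_devices order (consistent with world_size = len(tgt_devices) + 1 and src_rank=0 in the cited code); the real helper may differ in details:

```python
def get_dist_info_from_comm_plan(comm_plan, rank_in_cluster, rank_in_worker):
    """Sketch only: find this worker's rank inside the per-src_pp_rank collective group."""
    for src_pp_rank, args in comm_plan.items():
        for i, tgt in enumerate(args["tgt_devices"]):
            if tgt["rank"] == rank_in_cluster and tgt["device"]["rank"] == rank_in_worker:
                return i + 1, args  # rank in the group (sender is rank 0), plus the plan entry
    return None, None  # this worker is not part of the group
```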
## Concrete Example: 8-GPU Training → 1-GPU Inference

Assume the following configuration:

- actor_train: 8 workers, DP=2, PP=4 (each PP stage has 2 DP replicas)
- actor_infer: 1 worker, DP=1, PP=1

### Communication Plan Generation

The system generates the following communication plan:
```python
comm_plan = {
    0: {  # actor_train PP rank 0
        "tgt_devices": [{"rank": 0, "device": {"rank": 0}}],  # actor_infer rank 0
    },
    1: {  # actor_train PP rank 1
        "tgt_devices": [],  # nothing to send: actor_infer has only 1 PP stage
    },
    2: {  # actor_train PP rank 2
        "tgt_devices": [],
    },
    3: {  # actor_train PP rank 3
        "tgt_devices": [],
    },
}
```
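A small illustration of how a sender-side PP rank might read its entry from this plan (illustrative only, not ROLL code):

```python
def targets_for_pp_rank(comm_plan, my_pp_rank):
    """Return the cluster-local actor_infer ranks this PP rank should contact."""
    entry = comm_plan.get(my_pp_rank, {})
    # tgt["rank"] indexes into the actor_infer worker list (cluster-local rank).
    return [tgt["rank"] for tgt in entry.get("tgt_devices", [])]

print(targets_for_pp_rank(comm_plan, 0))  # [0]  -> send to actor_infer rank 0
print(targets_for_pp_rank(comm_plan, 2))  # []   -> nothing to send, skip
```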
### Weight Sending Flow
- **actor_train PP rank 0, DP rank 0 does the sending**:
  - Collects the weights of PP stage 0 (via DP all-gather)
  - Sends them to actor_infer rank 0 via broadcast
- **The remaining actor_train ranks do not send**:
  - PP ranks 1-3 have empty `tgt_devices`, so they skip sending
  - DP rank 1 workers do not send either (only DP rank 0 sends)
- **actor_infer rank 0 receives**:
  - It receives weights only from actor_train PP rank 0
  - Since actor_infer has only 1 PP stage, only the weights from PP rank 0 are needed
## DeepSpeed Strategy Implementation

The DeepSpeed strategy uses similar logic, but transfers the weights parameter by parameter [6].

Key differences:
- No buckets are used; parameters are transferred one at a time
- For ZeRO-3, the sharded parameters must first be gathered with `GatheredParameters` (sketched after this list)
- As before, only the worker with TP rank 0, CP rank 0, and DP rank 0 performs the broadcast
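For ZeRO-3, each sharded parameter is materialized with DeepSpeed's GatheredParameters context manager before it is sent. A minimal sketch of the gather pattern; the helper name iter_full_params is hypothetical, and the cited DeepSpeed model_update inlines this logic while also handling LoRA, P2P targets, and timing:

```python
import deepspeed

def iter_full_params(model, is_zero3: bool):
    """Yield (name, full_tensor) pairs, gathering ZeRO-3 shards on the fly."""
    for name, param in model.named_parameters():
        if is_zero3:
            # Temporarily reassembles the sharded parameter on this rank (read-only gather).
            with deepspeed.zero.GatheredParameters([param]):
                yield name, param.data.clone()
        else:
            yield name, param.data
```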
## How the Corresponding Worker Is Found

### 1. Via Ray Actor Handles

In the `model_update` method, `tgt_workers` is a list of Ray actor handles:
```python
def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):
    for worker in tgt_workers:
        ref = worker.broadcast_bucket.remote(...)
```
These handles are passed in at pipeline initialization via `set_model_update_pair`.

### 2. Via Rank Mapping

In the P2P update, the target worker is found by indexing with `p2p_tgt_device["rank"]`:
```python
for p2p_tgt_device in p2p_tgt_devices:
    p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
    ref = p2p_tgt_worker.update_parameter_in_bucket.remote(...)
```
This rank is the rank within the actor_infer cluster, not a global rank.

### 3. Via the Communication Group Name

In the collective broadcast, the communication group is identified by `group_name`:
```python
collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
```
The communication group was already established in `setup_collective_group` and encodes the correct rank mapping.
## Notes

The weight synchronization mechanism in ROLL uses the **communication plan (comm_plan)** to implement flexible rank mapping and supports arbitrary combinations of DP/PP/TP configurations. The key design principles are:

- Only the necessary ranks participate in the transfer: usually only the worker with DP rank 0, TP rank 0, and CP rank 0 sends
- One-to-one mapping along the PP dimension: each PP stage sends to the corresponding PP stage
- Mixed use of P2P and broadcast: P2P for specific devices, broadcast for bulk transfer
- Isolation via communication groups: different PP ranks use different communication groups to avoid interference

This design lets the system synchronize weights efficiently between clusters with different parallel configurations while minimizing communication overhead.
Wiki pages you might want to explore:
- Parameter Synchronization (alibaba/ROLL)
- Strategy Pattern and Backend Abstraction (alibaba/ROLL)
- State Offloading and Memory Management (alibaba/ROLL)
## Citations
File: roll/pipeline/rlvr/rlvr_pipeline.py (L309-313)
```python
self.set_model_update_pair(
    src_cluster=self.actor_train,
    tgt_cluster=self.actor_infer,
    frequency=self.pipeline_config.actor_train.model_update_frequency,
)
```
File: roll/distributed/strategy/strategy.py (L85-118)
```python
def setup_collective_group(self, model_update_name, comm_plan, backend=None):
    """单卡infer strategy可直接复用,多卡infer strategy需要自行管理"""
    if backend is None:
        backend = current_platform.communication_backend
    rank, comm_plan_args = get_dist_info_from_comm_plan(
        comm_plan, rank_in_cluster=self.worker.rank, rank_in_worker=0
    )
    if rank is None:
        logger.info(f"no comm_plan found for rank {self.worker.rank}/{0}")
        return
    group_name = comm_plan_args["group_name"]
    master_addr = comm_plan_args["master_addr"]
    master_port = comm_plan_args["master_port"]
    world_size = len(comm_plan_args["tgt_devices"]) + 1
    src_pp_rank = comm_plan_args["src_pp_rank"]
    logger.info(f"{group_name} rank: {rank} world_size: {world_size}, {comm_plan_args}")
    collective.init_collective_group(
        world_size, rank, backend=backend, group_name=group_name,
        master_addr=master_addr, master_port=master_port
    )
    # A small all_reduce for warmup.
    collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)
    if model_update_name not in self.model_update_comm_plan:
        self.model_update_comm_plan[model_update_name] = {}
    self.model_update_comm_plan[model_update_name][src_pp_rank] = dict(
        rank=rank,
        world_size=world_size,
        src_pp_rank=src_pp_rank,
        group_name=group_name,
        comm_plan=comm_plan,
        comm_plan_args=comm_plan_args,
    )
    logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")
```
File: roll/distributed/strategy/megatron_strategy.py (L494-534)
```python
def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):
    comm_plan = self.model_update_comm_plan[model_update_name][self.worker.rank_info.pp_rank]
    broadcast_time_cost = 0
    with Timer("model_update_total") as timer_total:
        for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(
            models=self.models_unwrapped, bucket_size=256 * 1024 * 1024
        ):
            ray.get(self.barrier.wait.remote())
            refs = []
            with Timer("broadcast") as timer_broadcast:
                for p2p_tgt_device in p2p_tgt_devices:
                    p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
                    ref = p2p_tgt_worker.update_parameter_in_bucket.remote(
                        model_update_name=model_update_name,
                        meta_infos=meta_infos,
                        buffer=buffer,
                        ranks_in_worker=[p2p_tgt_device["device"]["rank"]],
                    )
                    refs.append(ref)
                if (
                    self.worker.rank_info.tp_rank == 0
                    and self.worker.rank_info.cp_rank == 0
                    and self.worker.rank_info.dp_rank == 0
                ):
                    for worker in tgt_workers:
                        ref = worker.broadcast_bucket.remote(
                            model_update_name=model_update_name,
                            src_pp_rank=self.worker.rank_info.pp_rank,
                            meta_infos=meta_infos,
                            bucket_size=buffer.numel() * buffer.element_size(),
                        )
                        refs.append(ref)
                if len(broadcast_tgt_devices) > 0:
                    collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
                ray.get(refs)
            ray.get(self.barrier.wait.remote())
            broadcast_time_cost += timer_broadcast.last
    metrics = {
        "all_gather": timer_total.last - broadcast_time_cost,
        "broadcast": broadcast_time_cost,
    }
    return metrics
```
File: roll/third_party/vllm/worker_helper.py (L65-87)
```python
def setup_collective_group(self, comm_plan, backend, rank_in_cluster):
    self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {})
    rank, comm_plan_args = get_dist_info_from_comm_plan(
        comm_plan, rank_in_cluster=rank_in_cluster, rank_in_worker=dist.get_rank()
    )
    if rank is None:
        logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}")
        return
    group_name = comm_plan_args["group_name"]
    master_addr = comm_plan_args["master_addr"]
    master_port = comm_plan_args["master_port"]
    world_size = len(comm_plan_args["tgt_devices"]) + 1
    src_pp_rank = comm_plan_args["src_pp_rank"]
    collective.init_collective_group(
        world_size, rank, backend=backend, group_name=group_name,
        master_addr=master_addr, master_port=master_port
    )
    # A small all_reduce for warmup.
    collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)
    self.model_update_comm_plan[src_pp_rank] = dict(
        rank=rank,
        world_size=world_size,
        src_pp_rank=src_pp_rank,
        group_name=group_name,
        comm_plan=comm_plan,
        comm_plan_args=comm_plan_args,
    )
    logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")
```
File: roll/third_party/vllm/worker_helper.py (L89-95)
```python
def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):
    if src_pp_rank not in self.model_update_comm_plan:
        return
    comm_plan = self.model_update_comm_plan[src_pp_rank]
    buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)
    collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
    WorkerHelper.update_parameter_in_bucket(self, meta_infos, buffer, [dist.get_rank()])
```
File: roll/distributed/strategy/deepspeed_strategy.py (L526-633)
```python
def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):
    model = self.unwrap_model()
    if is_lora := (self.worker_config.model_args.lora_target is not None):
        all_params = self.collect_lora_params()
        peft_config = model.peft_config.get("default", None)
    else:
        all_params = list(model.named_parameters())
    comm_plan = self.model_update_comm_plan[model_update_name][self.worker.rank_info.pp_rank]
    model = self.unwrap_model()
    broadcast_time_cost = 0
    with Timer("model_update_total") as timer_total:
        for param_name, param in tqdm(all_params, desc="weight update progress", total=len(all_params)):
            shape = param.shape if not self.ds_config.is_zero3() else param.ds_shape
            if not self.ds_config.is_zero3():
                param_weight = param.data
                refs = []
                for p2p_tgt_device in p2p_tgt_devices:
                    p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
                    ref = p2p_tgt_worker.update_parameter.remote(
                        model_update_name=model_update_name,
                        parameter_name=param_name,
                        weight=param_weight,
                        ranks_in_worker=[p2p_tgt_device["device"]["rank"]],
                        is_lora=is_lora,
                    )
                    refs.append(ref)
                if (
                    self.worker.rank_info.tp_rank == 0
                    and self.worker.rank_info.cp_rank == 0
                    and self.worker.rank_info.dp_rank == 0
                ):
                    for worker in tgt_workers:
                        ref = worker.broadcast_parameter.remote(
                            model_update_name=model_update_name,
                            src_pp_rank=self.worker.rank_info.pp_rank,
                            dtype=param_weight.dtype,
                            shape=shape,
                            parameter_name=param_name,
                            is_lora=is_lora,
                        )
                        refs.append(ref)
                if len(broadcast_tgt_devices) > 0:
                    collective.broadcast(tensor=param_weight, src_rank=0, group_name=comm_plan["group_name"])
                ray.get(refs)
            else:
                with GatheredParameters([param]):
                    param_weight = param.data
                    with Timer("broadcast") as timer_broadcast:
                        refs = []
                        for p2p_tgt_device in p2p_tgt_devices:
                            p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]
                            ref = p2p_tgt_worker.update_parameter.remote(
                                model_update_name=model_update_name,
                                parameter_name=param_name,
                                weight=param_weight,
                                ranks_in_worker=[p2p_tgt_device["device"]["rank"]],
                                is_lora=is_lora,
                            )
                            refs.append(ref)
                        if (
                            self.worker.rank_info.tp_rank == 0
                            and self.worker.rank_info.cp_rank == 0
                            and self.worker.rank_info.dp_rank == 0
                        ):
                            for worker in tgt_workers:
                                ref = worker.broadcast_parameter.remote(
                                    model_update_name=model_update_name,
                                    src_pp_rank=self.worker.rank_info.pp_rank,
                                    dtype=param_weight.dtype,
                                    shape=shape,
                                    parameter_name=param_name,
                                    is_lora=is_lora,
                                )
                                refs.append(ref)
                        if len(broadcast_tgt_devices) > 0:
                            collective.broadcast(tensor=param_weight, src_rank=0, group_name=comm_plan["group_name"])
                        ray.get(refs)
                broadcast_time_cost += timer_broadcast.last
        if is_lora:
            with Timer("add_lora") as timer_add_lora:
                if (
                    self.worker.rank_info.tp_rank == 0
                    and self.worker.rank_info.cp_rank == 0
                    and self.worker.rank_info.dp_rank == 0
                ):
                    refs = []
                    for worker in tgt_workers:
                        ref = worker.add_lora.remote(peft_config=asdict(peft_config))
                        refs.append(ref)
                    ray.get(refs)
    metrics = {
        "broadcast": broadcast_time_cost,
    }
    if is_lora:
        metrics["all_gather"] = timer_total.last - broadcast_time_cost - timer_add_lora.last
        metrics["add_lora"] = timer_add_lora.last
    else:
        metrics["all_gather"] = timer_total.last - broadcast_time_cost
    return metrics
```
