当前位置: 首页 > news >正文

[大模型训练】权重更新

您询问的是在 ROLL 框架中,当发送端(actor_train)和接收端(actor_infer)都有 DP(Data Parallel)和 PP(Pipeline Parallel)分组时,权重是如何从 actor_train 发送到 actor_infer,以及如何找到对应的 worker 的

问题理解

在分布式训练场景中:

  • actor_train: 使用 DP + PP 进行训练,例如 8 个 workers 可能配置为 DP=2, PP=4
  • actor_infer: 也可能使用 DP + PP 进行推理,例如 4 个 workers 配置为 DP=2, PP=2

关键问题是:如何建立正确的通信映射,确保每个 actor_train worker 知道应该向哪些 actor_infer workers 发送权重

通信计划(Communication Plan)的建立

1. 在 Pipeline 初始化时创建通信计划

BasePipeline.set_model_update_pair 方法中,会创建从 actor_train 到 actor_infer 的通信计划: 1

这个方法会调用底层的通信计划生成逻辑,根据源集群和目标集群的并行配置,计算出每个 rank 应该与哪些 rank 通信。

2. 通信计划的结构

通信计划是一个嵌套字典,结构如下:

comm_plan = {src_pp_rank: {  # 源端的 pipeline parallel rank"group_name": "model_update_actor_train_0_to_actor_infer","master_addr": "10.0.0.1","master_port": "29500","tgt_devices": [  # 目标设备列表{"rank": 0, "device": {"rank": 0, "node_rank": 0, "gpu_rank": 0}},{"rank": 1, "device": {"rank": 0, "node_rank": 0, "gpu_rank": 0}},],"src_pp_rank": 0,}
}
```[2](#50-1) ### 3. 关键映射逻辑通信计划的生成遵循以下原则:**Pipeline Parallel 映射**:
- 每个 actor_train 的 PP rank 负责向对应的 actor_infer PP rank 发送权重
- 例如:actor_train PP rank 0 → actor_infer PP rank 0**Data Parallel 优化**:
- 在 DP 维度上,**通常只有一个 DP rank(通常是 rank 0)负责广播权重**
- 其他 DP ranks 不参与权重传输,因为它们的权重是相同的**Tensor Parallel 处理**:
- 如果 actor_train 和 actor_infer 的 TP size 相同,每个 TP rank 发送自己的权重分片
- 如果 TP size 不同,需要先 all-gather 权重,然后重新切分## Megatron Strategy 的实现在 Megatron 策略中,`model_update` 方法展示了完整的权重发送逻辑:<cite /> [3](#50-2) ### 关键步骤解析**步骤 1: 收集权重**
```python
for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(models=self.models_unwrapped, bucket_size=256 * 1024 * 1024
):
  • 在 DP/TP/PP 维度上 all-gather 权重,转换为 HF 格式的 bucket
  • 每个 bucket 约 256MB,包含多个参数

步骤 2: 同步屏障

ray.get(self.barrier.wait.remote())
  • 确保所有 actor_train ranks 都准备好发送

步骤 3: P2P 更新(点对点)

for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter_in_bucket.remote(...)
  • 对于特定的目标设备,直接通过 Ray RPC 发送权重
  • p2p_tgt_devices 是通信计划中指定的需要点对点传输的设备

步骤 4: Broadcast 更新(广播)

if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0
):for worker in tgt_workers:ref = worker.broadcast_bucket.remote(...)
  • 只有 TP rank 0, CP rank 0, DP rank 0 的 worker 负责广播
  • 向所有目标 workers 发送 broadcast 命令

步骤 5: Collective Broadcast

if len(broadcast_tgt_devices) > 0:collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])
  • 使用 NCCL collective broadcast 在通信组上广播权重
  • broadcast_tgt_devices 是通信计划中指定的需要广播的设备列表

接收端的处理

在 actor_infer 端,权重接收通过以下方法:

1. Broadcast Bucket 接收 4

工作流程:

  1. 检查 src_pp_rank 是否在 model_update_comm_plan
  2. 如果不在,说明当前 worker 不需要从这个 PP rank 接收权重,直接返回
  3. 如果在,创建空 buffer,通过 collective.broadcast 接收数据
  4. 调用 update_parameter_in_bucket 将权重加载到模型

2. 通信组的建立

在接收权重之前,actor_infer 必须先建立通信组: 5

关键点:

  • 通过 get_dist_info_from_comm_plan 从通信计划中提取当前 worker 的 rank 和 world_size
  • 如果 rank is None,说明当前 worker 不参与这个通信组,直接返回
  • 使用 collective.init_collective_group 初始化 NCCL 通信组
  • 执行 warmup all-reduce 验证通信组可用

具体示例:8 卡训练 → 1 卡推理

假设配置:

  • actor_train: 8 个 workers,DP=2, PP=4 (每个 PP stage 有 2 个 DP replicas)
  • actor_infer: 1 个 worker,DP=1, PP=1

通信计划生成

系统会生成如下通信计划:

comm_plan = {0: {  # actor_train PP rank 0"tgt_devices": [{"rank": 0, "device": {"rank": 0}}],  # actor_infer rank 0},1: {  # actor_train PP rank 1"tgt_devices": [],  # 不发送,因为 actor_infer 只有 1 个 PP stage},2: {  # actor_train PP rank 2"tgt_devices": [],},3: {  # actor_train PP rank 3"tgt_devices": [],},
}

权重发送流程

  1. actor_train PP rank 0, DP rank 0 负责发送:

    • 收集 PP stage 0 的权重(通过 DP all-gather)
    • 通过 broadcast 发送到 actor_infer rank 0
  2. actor_train 其他 ranks 不参与发送:

    • PP rank 1-3 的 tgt_devices 为空,跳过发送
    • DP rank 1 的 workers 也不发送(只有 DP rank 0 发送)
  3. actor_infer rank 0 接收:

    • 只从 actor_train PP rank 0 接收权重
    • 因为 actor_infer 只有 1 个 PP stage,所以只需要 PP rank 0 的权重

DeepSpeed Strategy 的实现

DeepSpeed 策略使用类似的逻辑,但逐参数传输: 6

关键区别:

  • 不使用 bucket,而是逐个参数传输
  • 对于 ZeRO-3,需要先用 GatheredParameters 收集分片参数
  • 同样只有 TP rank 0, CP rank 0, DP rank 0 负责广播

如何找到对应的 Worker

1. 通过 Ray Actor Handle

model_update 方法中,tgt_workers 是一个 Ray actor handle 列表:

def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):for worker in tgt_workers:ref = worker.broadcast_bucket.remote(...)

这些 handles 是在 pipeline 初始化时通过 set_model_update_pair 传递的。

2. 通过 Rank 映射

在 P2P 更新中,通过 p2p_tgt_device["rank"] 索引找到对应的 worker:

for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter_in_bucket.remote(...)

这里的 rank 是 actor_infer 集群内的 rank,不是全局 rank。

3. 通过通信组名称

在 collective broadcast 中,通过 group_name 找到对应的通信组:

collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])

通信组在 setup_collective_group 时已经建立,包含了正确的 rank 映射。

Notes

ROLL 框架的权重同步机制通过**通信计划(comm_plan)**实现了灵活的 rank 映射,支持任意的 DP/PP/TP 配置组合。关键设计原则是:

  1. 只有必要的 ranks 参与传输:通常只有 DP rank 0, TP rank 0, CP rank 0 负责发送
  2. PP 维度一对一映射:每个 PP stage 对应发送到相应的 PP stage
  3. 混合使用 P2P 和 Broadcast:P2P 用于特定设备,Broadcast 用于批量传输
  4. 通过通信组隔离:不同的 PP ranks 使用不同的通信组,避免干扰

这种设计使得系统可以高效地在不同并行配置的集群之间同步权重,同时最小化通信开销。

Wiki pages you might want to explore:

  • Parameter Synchronization (alibaba/ROLL)
  • Strategy Pattern and Backend Abstraction (alibaba/ROLL)

Wiki pages you might want to explore:

  • Strategy Pattern and Backend Abstraction (alibaba/ROLL)
  • State Offloading and Memory Management (alibaba/ROLL)

Citations

File: roll/pipeline/rlvr/rlvr_pipeline.py (L309-313)

        self.set_model_update_pair(src_cluster=self.actor_train,tgt_cluster=self.actor_infer,frequency=self.pipeline_config.actor_train.model_update_frequency,)

File: roll/distributed/strategy/strategy.py (L85-118)

    def setup_collective_group(self, model_update_name, comm_plan, backend=None):"""单卡infer strategy可直接复用,多卡infer strategy需要自行管理"""if backend is None:backend = current_platform.communication_backendrank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=self.worker.rank, rank_in_worker=0)if rank is None:logger.info(f"no comm_plan found for rank {self.worker.rank}/{0}")returngroup_name = comm_plan_args["group_name"]master_addr = comm_plan_args["master_addr"]master_port = comm_plan_args["master_port"]world_size = len(comm_plan_args["tgt_devices"]) + 1src_pp_rank = comm_plan_args["src_pp_rank"]logger.info(f"{group_name} rank: {rank} world_size: {world_size}, {comm_plan_args}")collective.init_collective_group(world_size, rank, backend=backend, group_name=group_name, master_addr=master_addr, master_port=master_port)# A small all_reduce for warmup.collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)if model_update_name not in self.model_update_comm_plan:self.model_update_comm_plan[model_update_name] = {}self.model_update_comm_plan[model_update_name][src_pp_rank] = dict(rank=rank,world_size=world_size,src_pp_rank=src_pp_rank,group_name=group_name,comm_plan=comm_plan,comm_plan_args=comm_plan_args,)logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")

File: roll/distributed/strategy/megatron_strategy.py (L494-534)

    def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):comm_plan = self.model_update_comm_plan[model_update_name][self.worker.rank_info.pp_rank]broadcast_time_cost = 0with Timer("model_update_total") as timer_total:for meta_infos, buffer in self.model.all_gather_weights_as_hf_bucket(models=self.models_unwrapped, bucket_size=256 * 1024 * 1024):ray.get(self.barrier.wait.remote())refs = []with Timer("broadcast") as timer_broadcast:for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter_in_bucket.remote(model_update_name=model_update_name,meta_infos=meta_infos, buffer=buffer, ranks_in_worker=[p2p_tgt_device["device"]["rank"]])refs.append(ref)if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0):for worker in tgt_workers:ref = worker.broadcast_bucket.remote(model_update_name=model_update_name,src_pp_rank=self.worker.rank_info.pp_rank,meta_infos=meta_infos,bucket_size=buffer.numel() * buffer.element_size(),)refs.append(ref)if len(broadcast_tgt_devices) > 0:collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])ray.get(refs)ray.get(self.barrier.wait.remote())broadcast_time_cost += timer_broadcast.lastmetrics = {"all_gather": timer_total.last - broadcast_time_cost,"broadcast": broadcast_time_cost,}return metrics

File: roll/third_party/vllm/worker_helper.py (L65-87)

    def setup_collective_group(self, comm_plan, backend, rank_in_cluster):self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {})rank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=rank_in_cluster,rank_in_worker=dist.get_rank())if rank is None:logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}")returngroup_name = comm_plan_args["group_name"]master_addr = comm_plan_args["master_addr"]master_port = comm_plan_args["master_port"]world_size = len(comm_plan_args["tgt_devices"]) + 1src_pp_rank = comm_plan_args["src_pp_rank"]collective.init_collective_group(world_size, rank, backend=backend, group_name=group_name,master_addr=master_addr, master_port=master_port)# A small all_reduce for warmup.collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)self.model_update_comm_plan[src_pp_rank] = dict(rank=rank,world_size=world_size,src_pp_rank=src_pp_rank,group_name=group_name,comm_plan=comm_plan,comm_plan_args=comm_plan_args)logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")

File: roll/third_party/vllm/worker_helper.py (L89-95)

    def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size):if src_pp_rank not in self.model_update_comm_plan:returncomm_plan = self.model_update_comm_plan[src_pp_rank]buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type)collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"])WorkerHelper.update_parameter_in_bucket(self, meta_infos, buffer, [dist.get_rank()])

File: roll/distributed/strategy/deepspeed_strategy.py (L526-633)

    def model_update(self, model_update_name, tgt_workers, broadcast_tgt_devices, p2p_tgt_devices):model = self.unwrap_model()if is_lora := (self.worker_config.model_args.lora_target is not None):all_params = self.collect_lora_params()peft_config = model.peft_config.get("default", None)else:all_params = list(model.named_parameters())comm_plan = self.model_update_comm_plan[model_update_name][self.worker.rank_info.pp_rank]model = self.unwrap_model()broadcast_time_cost = 0with Timer("model_update_total") as timer_total:for param_name, param in tqdm(all_params, desc="weight update progress", total=len(all_params)):shape = param.shape if not self.ds_config.is_zero3() else param.ds_shapeif not self.ds_config.is_zero3():param_weight = param.datarefs = []for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter.remote(model_update_name=model_update_name,parameter_name=param_name,weight=param_weight,ranks_in_worker=[p2p_tgt_device["device"]["rank"]],is_lora=is_lora,)refs.append(ref)if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0):for worker in tgt_workers:ref = worker.broadcast_parameter.remote(model_update_name=model_update_name,src_pp_rank=self.worker.rank_info.pp_rank,dtype=param_weight.dtype,shape=shape,parameter_name=param_name,is_lora=is_lora,)refs.append(ref)if len(broadcast_tgt_devices) > 0:collective.broadcast(tensor=param_weight, src_rank=0, group_name=comm_plan["group_name"])ray.get(refs)else:with GatheredParameters([param]):param_weight = param.datawith Timer("broadcast") as timer_broadcast:refs = []for p2p_tgt_device in p2p_tgt_devices:p2p_tgt_worker = tgt_workers[p2p_tgt_device["rank"]]ref = p2p_tgt_worker.update_parameter.remote(model_update_name=model_update_name,parameter_name=param_name,weight=param_weight,ranks_in_worker=[p2p_tgt_device["device"]["rank"]],is_lora=is_lora,)refs.append(ref)if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0):for worker in tgt_workers:ref = worker.broadcast_parameter.remote(model_update_name=model_update_name,src_pp_rank=self.worker.rank_info.pp_rank,dtype=param_weight.dtype,shape=shape,parameter_name=param_name,is_lora=is_lora,)refs.append(ref)if len(broadcast_tgt_devices) > 0:collective.broadcast(tensor=param_weight, src_rank=0, group_name=comm_plan["group_name"])ray.get(refs)broadcast_time_cost += timer_broadcast.lastif is_lora:with Timer("add_lora") as timer_add_lora:if (self.worker.rank_info.tp_rank == 0and self.worker.rank_info.cp_rank == 0and self.worker.rank_info.dp_rank == 0):refs = []for worker in tgt_workers:ref = worker.add_lora.remote(peft_config=asdict(peft_config))refs.append(ref)ray.get(refs)metrics = {"broadcast": broadcast_time_cost,}if is_lora:metrics["all_gather"] = timer_total.last - broadcast_time_cost - timer_add_lora.lastmetrics["add_lora"] = timer_add_lora.lastelse:metrics["all_gather"] = timer_total.last - broadcast_time_costreturn metrics
http://www.dtcms.com/a/573953.html

相关文章:

  • 哪做网站比较便宜外链发布
  • linux之中断子系统介绍(1)
  • 算法 day 45
  • 进入官方网站上海建网站方案
  • ABAP+WHERE字段长度不一致报错解决
  • WHAT - useCallback 深入理解
  • 怎么做自己的网站推广淘宝客怎么样做自己的网站
  • 网站建设大师程序员给别人做的网站违法
  • 文件属性获取与目录IO操作详解
  • 优秀网站首页广东省建设注册中心网站
  • 要将ITP集成到Jenkins Pipeline中,实现开发发版时自动触发自动化测试
  • Linux 定时监测 Java 服务
  • 体外产品的研发网站如何建设paypal网站做外贸
  • 浙江城乡建设局和住建局seo课程培训入门
  • 3系统需求调研项目整合管理
  • Nestjs框架: Consul健康检查与gRPC客户端动态管理优化方案
  • 开机自动启动activity
  • 医学图像分割评价指标Dice与HD95的详解
  • 杀毒软件杀毒原理(草稿)
  • 网站开发需要会的东西网页设计大赛主题
  • 如何将iPhone上的笔记传输到电脑
  • 发布公司信息的网站网推接单
  • MES 离散制造核心流程详解(含关键动作、角色与异常处理)
  • 网站建设方案与报价wordpress文章怎么生成标签
  • 雄安投资建设集团网站东莞网站建设咨询
  • ruoyi前端(vue3)框架,切换菜单白屏问题
  • HTML5+CSS3小实例:不用JS实现幽灵网格动画
  • 人工智能 机器学习 深度学习
  • 用C++从零开始实现的小型深度学习训练框架
  • 算法题(Python)数组篇 | 3.有序数组的平方