「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理
「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理
大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.py 和 distributed/comm_group.py 模块负责。
核心概念:CommGroup
distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。
- 初始化:
CommGroup根据传入的rank_list(一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局rank来创建对应的ProcessGroup。 - 关键属性:
group: 底层的 PyTorchProcessGroup。cpu_group: 对应的 CPUProcessGroup(用于 CPU 上的集合通信)。ranks_in_group: 当前CommGroup包含的所有 Rank 列表。group_size: 当前进程所在通信组的大小。rank_in_group: 当前进程在所在通信组内的局部 Rank。is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
- 通信操作封装: 提供了对
torch.distributed常用通信原语(如broadcast,all_reduce,all_gather,reduce_scatter等)的封装,自动传入正确的group参数。
CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。
并行状态管理 (parallel_state.py)
distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。
- 全局变量: 定义了
_WORLD_GROUP,_TP_GROUP,_PP_GROUP,_DP_GROUP,_EP_GROUP等全局变量,用于存储各个并行维度的CommGroup实例。 - 初始化函数 (
initialize_parallel_groups): 这是并行设置的核心入口。- 输入: TP, PP, DP, EP 的大小 (
tp_size,pp_size,dp_size,ep_size)。 - 获取环境信息: 获取全局
rank,local_rank,world_size。 - 按序初始化: 依次调用
initialize_world_group,initialize_tp_group,initialize_pp_group,initialize_dp_group,initialize_ep_group。 - 初始化逻辑: 每个
initialize_*_group函数根据并行维度的大小和当前rank计算出该维度对应的rank_list,然后创建CommGroup实例并赋值给相应的全局变量。例如:initialize_tp_group:world_size被划分为world_size // tp_size个 TP 组,每个组包含tp_size个连续的 Rank。initialize_pp_group:world_size被划分为world_size // pp_size个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为num_pp_groups的 Rank。initialize_dp_group: 类似 PP 组的划分方式。initialize_ep_group: 逻辑稍复杂。如果ep_size > 1:- 若
tp_size == ep_size且dp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。 - 若
dp_size == ep_size且tp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。 - 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
- 如果
ep_size == 1,则每个 Rank 自己构成一个 EP 组。
- 若
- 特殊处理:
initialize_pp_group中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
- 输入: TP, PP, DP, EP 的大小 (
- 访问接口: 提供
get_world_group(),get_tp_group(),get_pp_group(),get_dp_group(),get_ep_group(),get_tp_size(),get_dp_size(),get_ep_size()等函数,方便全局访问并行状态信息和通信组。 - 销毁:
destroy_parallel_groups()负责销毁创建的通信组。
使用流程
- 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
- 调用
initialize_parallel_groups初始化所有并行通信组。 - 在模型代码或算子实现中,通过
get_tp_group(),get_ep_group()等接口获取相应的CommGroup。 - 调用
CommGroup实例提供的通信方法(如tp_group.all_reduce(tensor))执行集合通信。
总结
「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。# 「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理
大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.py 和 distributed/comm_group.py 模块负责。
核心概念:CommGroup
distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。
- 初始化:
CommGroup根据传入的rank_list(一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局rank来创建对应的ProcessGroup。 - 关键属性:
group: 底层的 PyTorchProcessGroup。cpu_group: 对应的 CPUProcessGroup(用于 CPU 上的集合通信)。ranks_in_group: 当前CommGroup包含的所有 Rank 列表。group_size: 当前进程所在通信组的大小。rank_in_group: 当前进程在所在通信组内的局部 Rank。is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
- 通信操作封装: 提供了对
torch.distributed常用通信原语(如broadcast,all_reduce,all_gather,reduce_scatter等)的封装,自动传入正确的group参数。
CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。
并行状态管理 (parallel_state.py)
distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。
- 全局变量: 定义了
_WORLD_GROUP,_TP_GROUP,_PP_GROUP,_DP_GROUP,_EP_GROUP等全局变量,用于存储各个并行维度的CommGroup实例。 - 初始化函数 (
initialize_parallel_groups): 这是并行设置的核心入口。- 输入: TP, PP, DP, EP 的大小 (
tp_size,pp_size,dp_size,ep_size)。 - 获取环境信息: 获取全局
rank,local_rank,world_size。 - 按序初始化: 依次调用
initialize_world_group,initialize_tp_group,initialize_pp_group,initialize_dp_group,initialize_ep_group。 - 初始化逻辑: 每个
initialize_*_group函数根据并行维度的大小和当前rank计算出该维度对应的rank_list,然后创建CommGroup实例并赋值给相应的全局变量。例如:initialize_tp_group:world_size被划分为world_size // tp_size个 TP 组,每个组包含tp_size个连续的 Rank。initialize_pp_group:world_size被划分为world_size // pp_size个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为num_pp_groups的 Rank。initialize_dp_group: 类似 PP 组的划分方式。initialize_ep_group: 逻辑稍复杂。如果ep_size > 1:- 若
tp_size == ep_size且dp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。 - 若
dp_size == ep_size且tp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。 - 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
- 如果
ep_size == 1,则每个 Rank 自己构成一个 EP 组。
- 若
- 特殊处理:
initialize_pp_group中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
- 输入: TP, PP, DP, EP 的大小 (
- 访问接口: 提供
get_world_group(),get_tp_group(),get_pp_group(),get_dp_group(),get_ep_group(),get_tp_size(),get_dp_size(),get_ep_size()等函数,方便全局访问并行状态信息和通信组。 - 销毁:
destroy_parallel_groups()负责销毁创建的通信组。
使用流程
- 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
- 调用
initialize_parallel_groups初始化所有并行通信组。 - 在模型代码或算子实现中,通过
get_tp_group(),get_ep_group()等接口获取相应的CommGroup。 - 调用
CommGroup实例提供的通信方法(如tp_group.all_reduce(tensor))执行集合通信。
总结
「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。
