详解 torch.distributed.all_gather_into_tensor
文章目录
- 功能说明
- 函数原型
- 代码示例
- 输入输出说明
- 关键注意事项
torch.distributed.all_gather_into_tensor
是 PyTorch 分布式库中的一个集合通信函数,用于将多个进程的本地张量聚合到一个全局张量中,每个进程都会得到完整的聚合结果。
与 all_gather
相比,它直接将结果收集到一个预先分配的输出张量中,更节省内存。
功能说明
- 作用:在所有参与通信的进程间,将每个进程的本地张量片段聚合到一个全局张量。每个进程都会收到包含所有进程数据的完整张量。
- 适用场景:需要收集分布式环境中各进程的局部结果(如分片数据、中间计算结果)并拼接成完整数据的场景(例如分布式推理中的结果合并)。
函数原型
torch.distributed.all_gather_into_tensor(output_tensor: torch.Tensor,input_tensor: torch.Tensor,group: Optional[ProcessGroup] = None,async_op: bool = False
)
- 参数:
output_tensor
:预先分配的输出张量,用于存储所有进程的聚合结果(需在调用前确定正确形状)。input_tensor
:当前进程的输入张量(局部数据)。group
:通信组(默认使用全局组)。async_op
:是否异步执行(默认False
,同步执行)。
代码示例
以下示例展示了在 4 个进程的分布式环境中使用 all_gather_into_tensor
的流程:
import torch
import torch.distributed as dist
import os
import subprocessdef run(rank, size):# 初始化分布式环境os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group('gloo', rank=rank, world_size=size) # 使用gloo后端(CPU),也可换nccl(GPU)# 每个进程创建本地张量(示例:rank=0→[0,0], rank=1→[1,1], 以此类推)local_tensor = torch.tensor([rank] * 2, dtype=torch.float32)print(f"进程 {rank} 的本地张量: {local_tensor}")# 预先分配输出张量:总大小 = 进程数 × 本地张量大小output_tensor = torch.empty((size, 2), dtype=torch.float32) # 4个进程×2元素=8元素# 执行all_gather_into_tensor:聚合所有进程的本地张量到output_tensordist.all_gather_into_tensor(output_tensor, local_tensor)# 每个进程都会得到完整的聚合结果print(f"进程 {rank} 的聚合结果: {output_tensor}")def main():size = 4 # 总进程数# 使用subprocess启动多个进程processes = []for rank in range(size):p = subprocess.Popen(['python', '-c', f"from __main__ import run; run({rank}, {size})"])processes.append(p)for p in processes:p.wait()if __name__ == "__main__":main()
输入输出说明
- 输入:每个进程的
local_tensor
是形状为(2,)
的张量,内容为[rank, rank]
(例如进程 0 的输入是[0., 0.]
,进程 1 的输入是[1., 1.]
等)。 - 输出:每个进程的
output_tensor
是预先分配的形状为(4, 2)
的张量,聚合了所有 4 个进程的本地数据,结果为:tensor([[0., 0.],[1., 1.],[2., 2.],[3., 3.]])
关键注意事项
- 输出张量预分配:
output_tensor
必须在调用前分配好内存,且形状需满足(world_size, ) + input_tensor.shape
(假设按第 0 维度聚合)。 - 数据一致性:所有进程的
input_tensor
必须具有相同的形状和数据类型,否则会导致错误。 - 后端选择:CPU 环境推荐使用
gloo
后端,GPU 环境推荐使用nccl
后端以获得更高性能。 - 同步性:默认同步执行,函数返回时聚合已完成;若设
async_op=True
,需通过返回的Work
对象调用wait()
确保完成。
该函数比传统的 all_gather
更高效,因为它避免了中间列表的创建,直接写入预分配的张量,适合大规模数据聚合场景。