当前位置: 首页 > news >正文

详解 torch.distributed.all_gather_into_tensor

文章目录

  • 功能说明
  • 函数原型
  • 代码示例
  • 输入输出说明
  • 关键注意事项

torch.distributed.all_gather_into_tensor 是 PyTorch 分布式库中的一个集合通信函数,用于将多个进程的本地张量聚合到一个全局张量中,每个进程都会得到完整的聚合结果。

all_gather 相比,它直接将结果收集到一个预先分配的输出张量中,更节省内存。

功能说明

  • 作用:在所有参与通信的进程间,将每个进程的本地张量片段聚合到一个全局张量。每个进程都会收到包含所有进程数据的完整张量。
  • 适用场景:需要收集分布式环境中各进程的局部结果(如分片数据、中间计算结果)并拼接成完整数据的场景(例如分布式推理中的结果合并)。

函数原型

torch.distributed.all_gather_into_tensor(output_tensor: torch.Tensor,input_tensor: torch.Tensor,group: Optional[ProcessGroup] = None,async_op: bool = False
)
  • 参数
    • output_tensor:预先分配的输出张量,用于存储所有进程的聚合结果(需在调用前确定正确形状)。
    • input_tensor:当前进程的输入张量(局部数据)。
    • group:通信组(默认使用全局组)。
    • async_op:是否异步执行(默认 False,同步执行)。

代码示例

以下示例展示了在 4 个进程的分布式环境中使用 all_gather_into_tensor 的流程:

import torch
import torch.distributed as dist
import os
import subprocessdef run(rank, size):# 初始化分布式环境os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group('gloo', rank=rank, world_size=size)  # 使用gloo后端(CPU),也可换nccl(GPU)# 每个进程创建本地张量(示例:rank=0→[0,0], rank=1→[1,1], 以此类推)local_tensor = torch.tensor([rank] * 2, dtype=torch.float32)print(f"进程 {rank} 的本地张量: {local_tensor}")# 预先分配输出张量:总大小 = 进程数 × 本地张量大小output_tensor = torch.empty((size, 2), dtype=torch.float32)  # 4个进程×2元素=8元素# 执行all_gather_into_tensor:聚合所有进程的本地张量到output_tensordist.all_gather_into_tensor(output_tensor, local_tensor)# 每个进程都会得到完整的聚合结果print(f"进程 {rank} 的聚合结果: {output_tensor}")def main():size = 4  # 总进程数# 使用subprocess启动多个进程processes = []for rank in range(size):p = subprocess.Popen(['python', '-c', f"from __main__ import run; run({rank}, {size})"])processes.append(p)for p in processes:p.wait()if __name__ == "__main__":main()

输入输出说明

  • 输入:每个进程的 local_tensor 是形状为 (2,) 的张量,内容为 [rank, rank](例如进程 0 的输入是 [0., 0.],进程 1 的输入是 [1., 1.] 等)。
  • 输出:每个进程的 output_tensor 是预先分配的形状为 (4, 2) 的张量,聚合了所有 4 个进程的本地数据,结果为:
    tensor([[0., 0.],[1., 1.],[2., 2.],[3., 3.]])
    

关键注意事项

  1. 输出张量预分配output_tensor 必须在调用前分配好内存,且形状需满足 (world_size, ) + input_tensor.shape(假设按第 0 维度聚合)。
  2. 数据一致性:所有进程的 input_tensor 必须具有相同的形状和数据类型,否则会导致错误。
  3. 后端选择:CPU 环境推荐使用 gloo 后端,GPU 环境推荐使用 nccl 后端以获得更高性能。
  4. 同步性:默认同步执行,函数返回时聚合已完成;若设 async_op=True,需通过返回的 Work 对象调用 wait() 确保完成。

该函数比传统的 all_gather 更高效,因为它避免了中间列表的创建,直接写入预分配的张量,适合大规模数据聚合场景。

http://www.dtcms.com/a/352403.html

相关文章:

  • 15.examples\01-Micropython-Basics\demo_yield_task.py 加强版
  • 【实时Linux实战系列】基于实时Linux的生物识别系统
  • #Linux内存管理学以致用# 请你根据linux 内核struct page 结构体的双字对齐的设计思想,设计一个类似的结构体
  • 【测试需求分析】-需求来源分析(一)
  • 博士招生 | 香港大学 Intelligent Communication Lab 招收全奖博士
  • 【deepseek问答记录】:chatGPT的参数数量和上下文长度有关系吗?
  • AI Agent正在给传统数据仓库下“死亡通知书“
  • 读《精益数据分析》:用户行为热力图
  • 【拍摄学习记录】01-景别
  • 创龙3576ububuntu系统设置静态IP方法
  • 【Linux 进程】进程程序替换详解
  • 8.26网络编程——Modbus TCP
  • Git 高级技巧:利用 Cherry Pick 实现远程仓库的同步合并
  • 【自然语言处理与大模型】微调数据集如何构建
  • docker 的网络
  • shell默认命令替代、fzf
  • RCC_APB2PeriphClockCmd
  • sdi开发说明
  • 推荐系统王树森(三)粗排精排
  • STM32的Sg90舵机
  • Python入门教程之字符串类型
  • 日语学习-日语知识点小记-构建基础-JLPT-N3阶段(20):文法+单词第7回2
  • iPhone 17 Pro 全新配色确定,首款折叠屏 iPhone 将配备 Touch ID 及四颗镜头
  • 【测试需求分析】-需求类型的初步分析(二)
  • 【NuGet】引用nuget包后构建项目简单解析
  • day41-动静分离
  • 数字时代下的智能信息传播引擎
  • 仿真干货|解析Abaqus AMD的兼容与并行效率问题
  • 基于硅基流动API构建智能聊天应用的完整指南
  • 使用QML的Rectangle组件的边框属性