当前位置: 首页 > news >正文

PyTorch API 4 - 分布式通信、分布式张量

文章目录

  • 分布式通信包 - torch.distributed
    • 后端支持
      • PyTorch 内置的后端
      • 选择哪个后端?
      • 常见环境变量
        • 选择使用的网络接口
        • 其他NCCL环境变量
    • 基础概念
    • 初始化
    • 返回类型:`bool`
      • TCP初始化
      • 共享文件系统初始化
      • 环境变量初始化方法
    • 初始化后操作
    • 关闭处理
      • 重新初始化
    • DeviceMesh
    • 点对点通信
    • 同步与异步集合操作
    • 集合函数
    • 分布式键值存储
    • 分析集体通信性能
    • 多GPU集合函数
    • 第三方后端
    • 启动工具
    • 生成进程工具
    • 调试 `torch.distributed` 应用程序
      • Python 断点调试
      • 监控式屏障
  • 监控屏障需要 gloo 进程组执行主机端同步
      • `TORCH_DISTRIBUTED_DEBUG`
    • 日志记录
  • torch.distributed.tensor
    • PyTorch DTensor(分布式张量)
      • DTensor 类 API
      • 作为分布式通信器的DeviceMesh
      • DTensor 布局类型
    • 创建 DTensor 的不同方式
      • 从逻辑上的 torch.Tensor 创建 DTensor
      • DTensor 工厂函数
    • 调试
      • 日志记录
      • 调试工具
    • 实验性功能
  • torch.distributed.tensor
    • PyTorch DTensor(分布式张量)
      • DTensor 类 API
      • 作为分布式通信器的DeviceMesh
      • DTensor 布局类型
    • Different ways to create a DTensor
      • Create DTensor from a logical torch.Tensor
      • DTensor 工厂函数
    • 调试
      • 日志记录
      • 调试工具
    • Experimental Features
  • 通用Join上下文管理器
  • Torch Distributed Elastic
    • 快速开始
    • 文档


分布式通信包 - torch.distributed


注意:关于分布式训练相关功能的简要介绍,请参阅PyTorch分布式概述。


后端支持

torch.distributed 支持三种内置后端,每种后端具有不同的功能特性。下表展示了哪些功能可用于 CPU/CUDA 张量。
注意:MPI 仅在用于构建 PyTorch 的实现支持 CUDA 时,才能启用 CUDA 功能。

后端gloompinccl
设备类型CPUGPUCPUGPUCPUGPU
发送?
接收?
广播?
全归约?
归约?
全收集?
收集?
分散?
归约分散
全到全?
屏障?

PyTorch 内置的后端

PyTorch 分布式包支持 Linux(稳定版)、MacOS(稳定版)和 Windows(原型版)。在 Linux 平台上,默认会构建并包含 Gloo 和 NCCL 后端(NCCL 仅在 CUDA 环境下构建时包含)。MPI 是一个可选后端,只有从源码构建 PyTorch 时才能包含(例如在已安装 MPI 的主机上构建 PyTorch)。


注意:从 PyTorch v1.8 开始,Windows 支持除 NCCL 之外的所有集体通信后端。如果 init_process_group() 的 init_method 参数指向文件,则必须遵循以下格式:

  • 本地文件系统:init_method="file:///d:/tmp/some_file"
  • 共享文件系统:init_method="file://{machine_name}/{share_folder_name}/some_file"

与 Linux 平台相同,您可以通过设置环境变量 MASTER_ADDR 和 MASTER_PORT 来启用 TcpStore。


选择哪个后端?

过去我们经常被问到:“我应该使用哪个后端?”

  • 经验法则
    • 分布式 GPU 训练使用 NCCL 后端
    • 分布式 CPU 训练使用 Gloo 后端
  • 配备 InfiniBand 互连的 GPU 主机
    • 使用 NCCL,因为它是目前唯一支持 InfiniBand 和 GPUDirect 的后端
  • 配备以太网互连的 GPU 主机
    • 使用 NCCL,因为它目前能提供最佳的分布式 GPU 训练性能,尤其适用于多进程单节点或多节点分布式训练。如果遇到 NCCL 相关问题,可将 Gloo 作为备选方案。(注意:当前 Gloo 在 GPU 上的运行速度慢于 NCCL)
  • 配备 InfiniBand 互连的 CPU 主机
    • 若 InfiniBand 已启用 IP over IB 功能则使用 Gloo,否则改用 MPI。我们计划在后续版本中为 Gloo 添加 InfiniBand 支持
  • 配备以太网互连的 CPU 主机
    • 除非有特殊需求需使用 MPI,否则默认选择 Gloo

常见环境变量


选择使用的网络接口

默认情况下,NCCL 和 Gloo 后端都会尝试自动选择合适的网络接口。如果自动检测的接口不正确,可以通过以下环境变量手动指定(分别对应各自的后端):

  • NCCL_SOCKET_IFNAME,例如 export NCCL_SOCKET_IFNAME=eth0
  • GLOO_SOCKET_IFNAME,例如 export GLOO_SOCKET_IFNAME=eth0

如果使用 Gloo 后端,可以通过逗号分隔指定多个接口,例如:export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3。后端会以轮询方式在这些接口间分配操作。必须确保所有进程在该变量中指定相同数量的接口


其他NCCL环境变量

调试功能 - 当NCCL出现故障时,可设置NCCL_DEBUG=INFO来打印明确的警告信息以及基础的NCCL初始化信息。

您还可以使用NCCL_DEBUG_SUBSYS获取NCCL特定模块的详细日志。例如,设置NCCL_DEBUG_SUBSYS=COLL将打印集合通信调用的日志,这对调试卡死问题(特别是由集合操作类型或消息大小不匹配引发的问题)很有帮助。若遇拓扑结构检测失败的情况,设置NCCL_DEBUG_SUBSYS=GRAPH可查看详细检测结果,如需NCCL团队进一步协助,该日志可作为参考依据保存。

性能调优 - NCCL基于拓扑检测结果进行自动调优以减少用户工作量。在某些基于socket的系统中,用户仍可尝试调整NCCL_SOCKET_NTHREADSNCCL_NSOCKS_PERTHREAD来提升socket网络带宽。这两个环境变量已在AWS、GCP等云服务商环境中经过NCCL预调优。

完整NCCL环境变量列表请参阅NVIDIA NCCL官方文档


基础概念

torch.distributed 包为 PyTorch 提供了跨多个计算节点(运行在一台或多台机器上)的多进程并行支持及通信原语。torch.nn.parallel.DistributedDataParallel() 类基于此功能,通过封装任意 PyTorch 模型来提供同步分布式训练。这与 Multiprocessing package - torch.multiprocessing 和 torch.nn.DataParallel() 提供的并行方式不同,因为它支持多台网络连接的机器,并且需要用户显式地为每个进程启动主训练脚本的独立副本。

在单机同步场景下,torch.distributedtorch.nn.parallel.DistributedDataParallel() 封装器相比其他数据并行方法(包括 torch.nn.DataParallel())仍具有优势:

  • 独立优化器:每个进程维护自己的优化器,并在每次迭代中执行完整的优化步骤。虽然这看似冗余(因为梯度已在进程间收集并平均,各进程梯度相同),但省去了参数广播步骤,从而减少了节点间张量传输的时间开销。
  • 独立 Python 解释器:每个进程拥有独立的 Python 解释器,避免了单 Python 进程中驱动多个执行线程、模型副本或 GPU 时产生的额外解释器开销和 “GIL 争用”。这对于重度依赖 Python 运行时的模型(如包含循环层或大量小组件的模型)尤为重要。

初始化

在使用其他方法之前,需要通过 torch.distributed.init_process_group()torch.distributed.device_mesh.init_device_mesh() 函数初始化该包。这两个函数都会阻塞,直到所有进程都加入为止。


警告:初始化操作不是线程安全的。进程组的创建应在单一线程中执行,以防止不同进程间出现不一致的 ‘UUID’ 分配,并避免初始化期间的竞争条件导致程序挂起。


torch.distributed.is_available()

如果分布式包可用则返回 True

否则,torch.distributed 不会暴露任何其他 API。目前 torch.distributed 在 Linux、MacOS 和 Windows 平台上可用。若要从源码构建 PyTorch 时启用该功能,需设置:

USE_DISTRIBUTED=1

当前默认值为:Linux 和 Windows 系统下 USE_DISTRIBUTED=1,MacOS 系统下 USE_DISTRIBUTED=0

返回类型:bool


torch.distributed.init_process_group(backend=None, init_method=None, timeout=None, world_size=-1, rank=-1, store=None, group_name='', pg_options=None, device_id=None)

初始化默认的分布式进程组。

这将同时初始化分布式包。

初始化进程组主要有两种方式:

1、显式指定 storerankworld_size

2、指定 init_method(URL字符串)来指示如何发现对等节点。可选指定 rankworld_size,或将所有必需参数编码在URL中并省略它们

如果均未指定,则默认 init_method 为 “env://”。

参数说明

  • backend (str 或 Backend, 可选) - 使用的后端。根据构建配置,有效值包括 mpi、gloo、nccl、ucc 或第三方插件注册的后端。从 2.6 版本开始,若未提供 backend,c10d 将根据 device_id 参数(如提供)对应的设备类型使用注册的后端。当前已知的默认注册为:cuda 设备使用 nccl,cpu 设备使用 gloo。若 backend 和 device_id 均未提供,c10d 将自动检测运行机器的加速器并使用对应注册的后端(或 cpu)。该字段可接受小写字符串(如 “gloo”),也可通过 Backend 属性访问(如 Backend.GLOO)。注意:使用 nccl 后端时,若单机多进程,每个进程必须独占其使用的 GPU,进程间共享 GPU 可能导致死锁或 NCCL 非法使用。ucc 后端为实验性功能。
  • init_method (str, 可选) - 指定进程组初始化方式的 URL。若未指定 init_method 或 store,默认为 “env://”。与 store 参数互斥。
  • world_size (int, 可选) - 参与任务的进程总数。若指定 store 则必须提供。
  • rank (int, 可选) - 当前进程的排名(取值范围应为 0 到 world_size-1)。若指定 store 则必须提供。
  • store (Store, 可选) - 所有工作进程可访问的键值存储,用于交换连接/地址信息。与 init_method 互斥。
  • timeout (timedelta, 可选) - 进程组操作的超时时间。NCCL 默认 10 分钟,其他后端默认 30 分钟。超时后异步中止集合操作并终止进程。由于 CUDA 执行是异步的,继续执行用户代码可能不安全,因为失败的异步 NCCL 操作可能导致后续 CUDA 操作处理损坏数据。当设置 TORCH_NCCL_BLOCKING_WAIT 时,进程将阻塞等待此超时。
  • group_name (str, 可选, 已弃用) - 组名(该参数已被忽略)
  • pg_options (ProcessGroupOptions, 可选) - 进程组选项,用于在构建特定进程组时传递额外参数。目前仅支持 nccl 后端的 ProcessGroupNCCL.Options,可指定 is_high_priority_stream 让 nccl 后端在有计算内核等待时选择高优先级 CUDA 流。其他可配置 NCCL 的选项参见:https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
  • device_id (torch.device, 可选) - 绑定进程的特定设备,支持后端特定优化。目前仅在 NCCL 下有两个效果:立即形成通信器(直接调用 ncclCommInit* 而非延迟调用),子组尽可能使用 ncclCommSplit 以避免不必要的组创建开销。如需提前获知 NCCL 初始化错误,也可使用此字段。

注意事项

启用 backend == Backend.MPI 需在支持 MPI 的系统上从源码编译 PyTorch。

实验性说明

多后端支持目前处于实验阶段。未指定 backend 时,将同时创建 gloonccl 后端:CPU 张量的集合操作使用 gloo,CUDA 张量的集合操作使用 nccl。可通过格式为 “<设备类型>:<后端名称>,<设备类型>:<后端名称>” 的字符串指定自定义后端,例如:“cpu:gloo,cuda:custom_backend”。


torch.distributed.device_mesh.init_device_mesh(device_type, mesh_shape, *, mesh_dim_names=None)

根据device_typemesh_shapemesh_dim_names参数初始化一个DeviceMesh。

这会创建一个具有n维数组布局的DeviceMesh,其中n是mesh_shape的长度。

如果提供了mesh_dim_names,则每个维度会被标记为mesh_dim_names[i]

注意init_device_mesh遵循SPMD编程模型,意味着相同的PyTorch Python程序会在集群中的所有进程/rank上运行。请确保mesh_shape(描述设备布局的n维数组的维度)在所有rank上保持一致。不一致的mesh_shape可能导致程序挂起。

注意:如果找不到进程组,init_device_mesh会在后台初始化分布式通信所需的分布式进程组/组。

参数

  • device_type (str) - 网格的设备类型。当前支持:“cpu”、“cuda/cuda-like”。不允许传入带有GPU索引的设备类型,如"cuda:0"。
  • mesh_shape (Tuple[int]) - 定义描述设备布局的多维数组维度的元组。
  • mesh_dim_names (Tuple[str], 可选) - 分配给描述设备布局的多维数组每个维度的网格维度名称元组。其长度必须与mesh_shape的长度匹配。mesh_dim_names中的每个字符串必须是唯一的。

返回

一个表示设备布局的DeviceMesh对象。

返回类型

DeviceMesh

示例:


>>> from torch.distributed.device_mesh import init_device_mesh
>>> >
>>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
>>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

torch.distributed.is_initialized()

检查默认进程组是否已初始化。

返回类型:bool


torch.distributed.is_mpi_available()

检查 MPI 后端是否可用。

返回类型:bool


torch.distributed.is_nccl_available()

检查NCCL后端是否可用。

返回类型:bool


torch.distributed.is_gloo_available()

检查 Gloo 后端是否可用。

返回类型:bool


torch.distributed.distributed_c10d.is_xccl_available()

检查XCCL后端是否可用。

返回类型:bool


torch.distributed.is_torchelastic_launched()

检查当前进程是否通过 torch.distributed.elastic(即 torchelastic)启动。

通过检测环境变量 TORCHELASTIC_RUN_ID 是否存在作为判断依据。这是一个合理的代理指标,因为 TORCHELASTIC_RUN_ID 映射到 rendezvous id(该值始终为非空,用于标识作业ID以实现节点发现)。

返回类型:bool

目前支持三种初始化方法:

TCP初始化

有两种使用TCP进行初始化的方式,两者都需要一个所有进程均可访问的网络地址和指定的world_size。第一种方式要求指定一个属于rank 0进程的地址。这种初始化方法要求所有进程都手动指定rank。

请注意,最新版本的分布式包不再支持多播地址。group_name参数也已弃用。


import torch.distributed as dist# Use address of one of the machines
dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456',   rank=args.rank, world_size=4)

共享文件系统初始化

另一种初始化方法利用了组内所有机器均可访问的共享文件系统,并配合指定的world_size参数。URL应以file://开头,并指向共享文件系统中某个不存在文件(位于已存在的目录)的路径。文件系统初始化会自动创建该文件(若不存在),但不会删除文件。因此,您需要确保在下一次对相同文件路径/名称调用init_process_group()前清理该文件。

请注意,最新版分布式包已不再支持自动分配rank,同时group_name参数也已弃用。


警告:此方法假设文件系统支持通过fcntl进行锁定——大多数本地系统和NFS都支持此功能。


警告:此方法总会创建文件,并会在程序结束时尽力清理和删除文件。换句话说,每次使用文件初始化方法时都需要一个全新的空文件才能成功初始化。如果重复使用前次初始化未清理的同一文件,将导致意外行为,通常会造成死锁和故障。因此,尽管该方法会尽力清理文件,但如果自动删除失败,您必须确保在训练结束后删除该文件,以防下次重复使用同一文件。当您计划对同一文件名多次调用init_process_group()时,这一点尤为重要。

简而言之,如果文件未被移除/清理,而您再次对该文件调用init_process_group(),预期会发生故障。经验法则是:确保每次调用init_process_group()时,目标文件不存在或是空文件。


import torch.distributed as dist# rank should always be specified
dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile',   world_size=4, rank=args.rank)

环境变量初始化方法

该方法会从环境变量中读取配置,允许用户完全自定义信息的获取方式。需要设置的环境变量包括:

  • MASTER_PORT - 必填;必须是 rank 0 机器上的空闲端口
  • MASTER_ADDR - 必填(rank 0 除外);rank 0 节点的地址
  • WORLD_SIZE - 必填;可以在此处设置,也可以在初始化函数调用时设置
  • RANK - 必填;可以在此处设置,也可以在初始化函数调用时设置

rank 为 0 的机器将用于建立所有连接。

这是默认的初始化方法,意味着无需指定 init_method(或可设为 env://)。


初始化后操作

运行 torch.distributed.init_process_group() 后,即可使用以下函数。要检查进程组是否已完成初始化,请调用 torch.distributed.is_initialized()


class torch.distributed.Backend(name)

一个类似枚举的后端类。

可用后端类型:GLOO、NCCL、UCC、MPI、XCCL 以及其他已注册的后端。

该类的值为小写字符串,例如 "gloo"。可以通过属性访问,例如 Backend.NCCL

此类可直接调用来解析字符串,例如 Backend(backend_str) 会检查 backend_str 是否有效,若有效则返回解析后的小写字符串。它也接受大写字符串,例如 Backend("GLOO") 会返回 "gloo"

注意:条目 Backend.UNDEFINED 存在但仅用作某些字段的初始值。用户既不应直接使用它,也不应假定其存在。


CLASSMETHOD register_backend(name, func, extended_api=False, devices=None)

使用给定的名称和实例化函数注册一个新的后端。

这个类方法被第三方 ProcessGroup 扩展用于注册新的后端。

参数

  • name (str)ProcessGroup 扩展的后端名称。它应该与 init_process_group() 中的名称匹配。
  • func (function) – 实例化后端的函数处理程序。该函数应在后端扩展中实现,并接受四个参数,包括 storerankworld_sizetimeout
  • extended_api ([bool], 可选) – 后端是否支持扩展参数结构。默认值:False。如果设置为 True,后端将获得一个 c10d::DistributedBackendOptions 实例,以及一个由后端实现定义的进程组选项对象。
  • device (str 或 str 列表, 可选) – 该后端支持的设备类型,例如 “cpu”、“cuda” 等。如果为 None,则假定同时支持 “cpu” 和 “cuda”。

注意:对第三方后端的支持目前处于实验阶段,可能会发生变化。


torch.distributed.get_backend(group=None)

返回给定进程组的后端。

参数

  • group (ProcessGroup, 可选) – 要操作的进程组。默认为通用的主进程组。如果指定了其他特定组,调用进程必须是该group的成员。

返回值:以小写字符串形式返回给定进程组的后端。

返回类型:Backend


torch.distributed.get_rank(group=None)

返回当前进程在指定group中的排名,若无指定则返回默认值。

排名是分布式进程组中分配给每个进程的唯一标识符。这些排名始终是从0到world_size的连续整数。

参数

  • group (ProcessGroup, 可选) – 要操作的进程组。如果为None,则使用默认进程组。

返回值:进程组的排名

  • 如果不在该组中,则返回-1

返回类型:int


torch.distributed.get_world_size(group=None)

返回当前进程组中的进程数量。

参数

  • group (ProcessGroup, 可选) – 要操作的进程组。如果为None,则使用默认进程组。

返回值:进程组的全局大小

如果不在该组中,则返回-1

返回类型:int


关闭处理

在程序退出时,通过调用destroy_process_group()来清理资源非常重要。

推荐遵循的最简单模式是:在训练脚本中不再需要通信的地方(通常是在main()函数末尾附近),通过调用destroy_process_group()并保持group参数为默认值None,来销毁所有进程组和后端。每个训练器进程应该调用一次,而不是在外部的进程启动器层面调用。

如果在超时时间内,某个进程组(pg)中的所有rank都没有调用destroy_process_group(),特别是当应用中存在多个进程组时(例如用于N维并行的情况),可能会导致程序退出时挂起。这是因为ProcessGroupNCCL的析构函数会调用ncclCommAbort,而这个调用必须是集体操作,但如果由Python的垃圾回收器触发ProcessGroupNCCL析构函数的调用顺序是不确定的。显式调用destroy_process_group()可以确保所有rank以一致的顺序调用ncclCommAbort,并避免在ProcessGroupNCCL析构期间调用ncclCommAbort。


重新初始化

destroy_process_group 也可用于销毁单个进程组。一个典型应用场景是容错训练,其中进程组可能在运行时被销毁后重新初始化。这种情况下,关键是在调用销毁操作之后、重新初始化之前,通过非torch.distributed原语的其他方式同步训练器进程。由于实现此类同步的复杂性,该行为目前处于未支持/未测试状态,属于已知问题。若此场景对您造成阻碍,请提交GitHub issue或RFC。


默认情况下,集合操作作用于默认组(也称为全局组),并要求所有进程都参与分布式函数调用。然而,某些工作负载可能受益于更细粒度的通信。这正是分布式组发挥作用的地方。new_group() 函数可用于创建包含任意进程子集的新组。该函数返回一个不透明的组句柄,可作为 group 参数传递给所有集合操作(集合操作是指那些以特定编程模式交换信息的分布式函数)。


torch.distributed.new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None, device_id=None)

创建一个新的分布式进程组。

该函数要求主进程组中的所有进程(即参与分布式作业的所有进程)都必须进入此函数,即使它们不会成为该组的成员。此外,所有进程必须以相同的顺序创建进程组。

警告:安全并发使用规范:

当使用NCCL后端的多进程组时,用户必须确保所有进程间集合操作的执行顺序全局一致。

如果单个进程内的多个线程发起集合操作,需要通过显式同步来确保执行顺序的一致性。

使用torch.distributed异步通信API时,会返回一个工作对象,通信内核会被放入独立的CUDA流中,从而实现通信与计算的重叠。当一个进程组发起一个或多个异步操作后,必须通过调用work.wait()与其他CUDA流同步,才能使用另一个进程组。

详见《并发使用多个NCCL通信器》。

参数说明

  • ranks (list[int]) - 组成员rank列表。若为None则包含所有rank,默认为None
  • timeout (timedelta, 可选) - 超时设置,详见init_process_group说明
  • backend (str 或 [Backend](https://pytorch.org/docs/stable/data.html#torch.distributed.Backend "torch.distributed.Backend"), 可选) - 使用的后端。根据构建配置可选gloonccl,默认使用全局组的后端。应传入小写字符串(如"gloo"),也可通过Backend属性指定(如Backend.GLOO)。传入None时将使用默认进程组的后端
  • pg_options (ProcessGroupOptions, 可选) - 进程组配置选项,用于指定特殊参数。例如对nccl后端可设置is_high_priority_stream来启用高优先级CUDA流。其他NCCL配置选项参见类型文档
  • use_local_synchronization ([bool], 可选) - 在进程组创建结束时执行组内局部屏障。与非成员rank不同,这些rank无需调用API且不参与屏障
  • group_desc (str, 可选) - 进程组的描述字符串
  • device_id (torch.device, 可选) - 要绑定的特定设备。若指定此参数,new_group会立即尝试初始化该设备的通信后端

返回值

返回分布式组的句柄,可用于集合调用。若当前rank不在ranks中则返回GroupMember.NON_GROUP_MEMBER

注意事项

1、use_local_synchronization不兼容MPI后端

2、在大型集群和小型进程组中使用use_local_synchronization=True可能显著提升性能,但需注意这会改变集群行为(非成员rank不参与屏障)

3、当各rank创建多个重叠进程组时,use_local_synchronization=True可能导致死锁。为避免此问题,需确保所有rank遵循相同的全局创建顺序


torch.distributed.get_group_rank(group, global_rank)

将全局排名转换为组内排名。

如果 global_rank 不属于 group 的成员,此操作会抛出 RuntimeError。

参数

  • group (ProcessGroup) – 用于查找相对排名的进程组。
  • global_rank (int) – 要查询的全局排名。

返回值

返回 global_rank 相对于 group 的组内排名

返回类型

int

注意:在默认进程组上调用此函数会返回原值


torch.distributed.get_global_rank(group, group_rank)

将组内排名转换为全局排名。

如果 group_rank 不属于该组,将抛出 RuntimeError。

参数

  • group (ProcessGroup) – 用于查询全局排名的进程组。
  • group_rank ( int ) – 需要查询的组内排名。

返回值:group_rank 相对于 group 的全局排名

返回类型:int

注意:在默认进程组上调用此函数将返回原值


torch.distributed.get_process_group_ranks(group)

获取与group关联的所有排名。

参数

  • group (ProcessGroup) – 要从中获取所有排名的ProcessGroup。

返回值:按组内排名排序的全局排名列表。

返回类型:list [int]


DeviceMesh

DeviceMesh 是一种更高层次的抽象,用于管理进程组(或 NCCL 通信器)。它允许用户轻松创建节点间和节点内的进程组,而无需关心如何为不同的子进程组正确设置 ranks,并帮助轻松管理这些分布式进程组。可以通过 init_device_mesh() 函数创建新的 DeviceMesh,其中 mesh shape 参数用于描述设备拓扑结构。


class torch.distributed.device_mesh.DeviceMesh(device_type, mesh, *, mesh_dim_names=None, _init_backend=True)

DeviceMesh 表示一个设备网格,其中设备的布局可以表示为一个 n 维数组,该 n 维数组的每个值是默认进程组 ranks 的全局 ID。

DeviceMesh 可用于描述集群中设备的布局,并作为集群内设备列表间通信的代理。

DeviceMesh 可用作上下文管理器。

注意:DeviceMesh 遵循 SPMD 编程模型,这意味着相同的 PyTorch Python 程序会在集群中的所有进程/ranks 上运行。因此,用户需要确保描述设备布局的网格数组在所有 ranks 上保持一致。不一致的网格会导致静默挂起。

参数

  • device_type (str) – 网格的设备类型。当前支持:“cpu”、“cuda/cuda-like”。
  • mesh (ndarray) – 描述设备布局的多维数组或整数张量,其中 ID 是默认进程组的全局 ID。

返回

一个表示设备布局的 DeviceMesh 对象。

返回类型:DeviceMesh

以下程序以 SPMD 方式在每个进程/rank 上运行。在此示例中,我们有 2 台主机,每台主机有 4 个 GPU。

在网格的第一个维度上进行归约操作会跨列 (0, 4), … 和 (3, 7) 进行,在网格的第二个维度上进行归约操作会跨行 (0, 1, 2, 3) 和 (4, 5, 6, 7) 进行。

示例:


>>> from torch.distributed.device_mesh import DeviceMesh
>>> >
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host(dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])

static from_group(group, device_type, mesh=None, *, mesh_dim_names=None)

基于现有的ProcessGroup或一组ProcessGroup列表,构造指定device_typeDeviceMesh

构造的设备网格维度数与传入的进程组数量相同。例如:

  • 传入单个进程组时,生成1D网格
  • 传入2个进程组列表时,生成2D网格

当传入多个进程组时,必须提供meshmesh_dim_names参数。进程组的传入顺序决定网格拓扑结构,例如第一个进程组对应DeviceMesh的第0维度。

传入的mesh张量必须满足:

1、维度数与进程组数量相同

2、张量维度顺序与进程组传入顺序一致

参数说明

  • group (ProcessGroup* 或 list[ProcessGroup]) - 现有进程组或进程组列表
  • device_type (str) - 网格设备类型,当前支持:“cpu”、“cuda/cuda-like”。禁止传入带GPU索引的类型(如"cuda:0")
  • mesh (torch.Tensor 或 *ArrayLike, 可选) - 描述设备布局的多维数组/整型张量,ID为默认进程组的全局ID。默认为None
  • mesh_dim_names (tuple[str], 可选) - 为设备布局数组各维度命名的元组,其长度必须与mesh_shape匹配,且每个字符串必须唯一。默认为None

返回值:表示设备布局的DeviceMesh对象

返回类型:DeviceMesh


get_all_groups()

返回所有网格维度的进程组列表。

返回值:一个包含 ProcessGroup 对象的列表。

返回类型:list [torch.distributed.distributed_c10d.ProcessGroup]


get_coordinate()

返回当前秩相对于网格所有维度的相对索引。如果该秩不属于网格,则返回 None。

返回类型:Optional[list [int ]]


get_group(mesh_dim=None)

返回由mesh_dim指定的单个ProcessGroup。如果未指定mesh_dim且DeviceMesh是一维的,则返回该mesh中唯一的ProcessGroup。

参数

  • mesh_dim (str/python:int, 可选) - 可以是mesh维度的名称或索引
  • None. (默认值为) -

返回

一个ProcessGroup对象。

返回类型:ProcessGroup


get_local_rank(mesh_dim=None)

返回给定设备网格维度(mesh_dim)的本地秩。

参数

  • mesh_dim (str/python:int, 可选) - 可以是网格维度的名称或索引
  • None. (网格维度的默认值) -

返回值:表示本地秩的整数值。

返回类型:int

以下程序以SPMD方式在每个进程/秩上运行。本例中,我们使用2台主机,每台主机配备4个GPU。

在秩0、1、2、3上调用mesh_2d.get_local_rank(mesh_dim=0)将返回0;在秩4、5、6、7上调用mesh_2d.get_local_rank(mesh_dim=0)将返回1;在秩0、4上调用mesh_2d.get_local_rank(mesh_dim=1)将返回0;在秩1、5上调用mesh_2d.get_local_rank(mesh_dim=1)将返回1。

在秩2、6上调用mesh_2d.get_local_rank(mesh_dim=1)将返回2;在秩3、7上调用mesh_2d.get_local_rank(mesh_dim=1)将返回3。


示例:

>>> from torch.distributed.device_mesh import DeviceMesh
>>> >
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host(dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])

get_rank()

返回当前全局排名。

返回值类型:int


点对点通信


torch.distributed.send(tensor, dst=None, group=None, tag=0, group_dst=None)

同步发送张量。

警告:NCCL后端不支持tag参数。

参数说明

  • tensor ( Tensor ) - 要发送的张量。
  • dst ( int ) - 全局进程组中的目标rank(不受group参数影响)。目标rank不应与当前进程的rank相同。
  • group (ProcessGroup, 可选) - 要操作的工作进程组。如果为None,将使用默认进程组。
  • tag ( int , 可选) - 用于匹配远程接收操作的标记
  • group_dst ( int , 可选) - 在group中的目标rank。不能同时指定dstgroup_dst参数。

torch.distributed.recv(tensor, src=None, group=None, tag=0, group_src=None)

同步接收一个张量。

警告:NCCL后端不支持tag参数。

参数

  • tensor ( Tensor ) - 用于填充接收数据的张量。
  • src ( int , 可选) - 全局进程组中的源rank(不受group参数影响)。若未指定,将从任意进程接收数据。
  • group (ProcessGroup, 可选) - 要操作的工作进程组。若为None,则使用默认进程组。
  • tag ( int , 可选) - 用于匹配远程发送操作的标签
  • group_src ( int , 可选) - 目标进程在group中的rank。不可同时指定srcgroup_src

返回值:发送方rank

  • 若不属于该进程组,则返回-1

返回类型:int

isend()irecv()

在使用时会返回分布式请求对象。通常不建议手动创建这类对象,因此其具体类型不作规定,但保证支持以下两种方法:

  • is_completed() - 若操作完成则返回True
  • wait() - 阻塞进程直至操作完成

is_completed()方法一旦返回结果,其返回值必定为True。


torch.distributed.isend(tensor, dst=None, group=None, tag=0, group_dst=None)

异步发送张量。

警告:在请求完成前修改 tensor 会导致未定义行为。

警告:NCCL 后端不支持 tag 参数。

与阻塞式的 send 不同,isend 允许 src == dst 排名,即支持向自身发送。

参数

  • tensor (Tensor) – 待发送的张量。
  • dst (int) – 全局进程组中的目标排名(不受 group 参数影响)。
  • group (ProcessGroup, 可选) – 操作的目标进程组。若为 None,则使用默认进程组。
  • tag (int, 可选) – 用于匹配远程 recv 的标记。
  • group_dst (int, 可选)group 中的目标排名。不可同时指定 dstgroup_dst

返回

一个分布式请求对象。若不属于该进程组则返回 None。

返回类型

Optional[Work]


torch.distributed.irecv(tensor, src=None, group=None, tag=0, group_src=None)

异步接收一个张量。

警告:NCCL后端不支持tag参数。

与阻塞式的recv不同,irecv允许src等于dst的rank,即可以从自身接收数据。

参数

  • tensor ( Tensor ) – 用于填充接收数据的张量。
  • src ( int , 可选) – 全局进程组中的源rank(不受group参数影响)。如果未指定,将从任意进程接收数据。
  • group (ProcessGroup, 可选) – 要操作的工作进程组。如果为None,则使用默认进程组。
  • tag ( int , 可选) – 用于匹配远程发送的接收标记
  • group_src ( int , 可选) – 在group中的目标rank。不能同时指定srcgroup_src

返回值:一个分布式请求对象。

如果不在该进程组中,则返回None

返回类型:Optional[Work]


torch.distributed.send_object_list(object_list, dst=None, group=None, device=None, group_dst=None)

同步发送 object_list 中可序列化的对象。

send() 类似,但可以传递 Python 对象。

注意,object_list 中的所有对象必须可序列化才能发送。

参数

  • object_list (List[Any]) – 要发送的输入对象列表。每个对象必须可序列化。接收方必须提供大小相等的列表。
  • dst (int) – 发送 object_list 的目标 rank。目标 rank 基于全局进程组(与 group 参数无关)。
  • group (Optional[ProcessGroup]) – (可选)要操作的进程组。如果为 None,则使用默认进程组。默认为 None
  • device (torch.device, optional) – 如果不为 None,对象会被序列化并转换为张量,发送前移动到 device。默认为 None
  • group_dst (int, optional)group 上的目标 rank。必须指定 dstgroup_dst 之一,但不能同时指定。

返回

None

注意:对于基于 NCCL 的进程组,对象的内部张量表示必须在通信前移动到 GPU 设备。此时使用的设备由 torch.cuda.current_device() 给出,用户需确保通过 torch.cuda.set_device() 设置,使每个 rank 拥有独立的 GPU。

警告send_object_list() 隐式使用 pickle 模块,已知其不安全。恶意构造的 pickle 数据可能在反序列化时执行任意代码。仅对可信数据调用此函数。

警告:使用 GPU 张量调用 send_object_list() 支持不佳且效率低下,因为张量会被序列化,导致 GPU-CPU 传输。建议改用 send()

示例

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 2、>>    objects = ["foo", 12, {1: 2}] # any picklable object
>>>     dist.send_object_list(objects, dst=1, device=device)
>>> else:
>>>     objects = [None, None, None]
>>>     dist.recv_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]

torch.distributed.recv_object_list(object_list, src=None, group=None, device=None, group_src=None)

同步接收object_list中的可序列化对象。

类似于recv(),但可以接收Python对象。

参数

  • object_list (List[Any]) - 用于接收对象的列表。必须提供一个与发送列表大小相等的尺寸列表。
  • src (int, 可选) - 接收object_list的源进程排名。源排名基于全局进程组(无论group参数如何)。如果设置为None,将从任意排名接收。默认为None
  • group (Optional[ProcessGroup]) - (ProcessGroup, 可选): 要操作的进程组。如果为None,将使用默认进程组。默认为None
  • device (torch.device, 可选) - 如果不为None,则在此设备上接收。默认为None
  • group_src (int, 可选) - group上的目标排名。不能同时指定srcgroup_src

返回

发送方排名。如果排名不属于该组,则为-1。如果排名属于该组,object_list将包含来自src排名的发送对象。

注意:对于基于NCCL的进程组,对象的内部张量表示必须在通信之前移动到GPU设备。在这种情况下,使用的设备由torch.cuda.current_device()给出,用户有责任通过torch.cuda.set_device()确保每个排名都有一个单独的GPU。

警告recv_object_list()隐式使用pickle模块,已知其不安全。可能构造恶意的pickle数据,在反序列化期间执行任意代码。仅对可信数据调用此函数。

警告:使用GPU张量调用recv_object_list()不受良好支持且效率低下,因为张量会被pickle,导致GPU-CPU传输。请考虑改用recv()

示例:


>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 2、>>    objects = ["foo", 12, {1: 2}] # any picklable object
>>>     dist.send_object_list(objects, dst=1, device=device)
>>> else:
>>>     objects = [None, None, None]
>>>     dist.recv_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]

torch.distributed.batch_isend_irecv(p2p_op_list)

异步发送或接收一批张量并返回请求列表。

处理 p2p_op_list 中的每个操作,并返回对应的请求。当前支持 NCCL、Gloo 和 UCC 后端。

参数

  • p2p_op_list (list[torch.distributed.distributed_c10d.P2POp]) – 点对点操作列表(每个操作的类型为 torch.distributed.P2POp)。列表中的 isend/irecv 顺序很重要,需要与远程端的对应 isend/irecv 匹配。

返回

通过调用 op_list 中的对应操作返回的分布式请求对象列表。

返回类型

list [torch.distributed.distributed_c10d.Work]


示例:

>>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
>>> recv_tensor = torch.randn(2, dtype=torch.float32)
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
>>> recv_op = dist.P2POp(
...     dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
... )
>>> reqs = batch_isend_irecv([send_op, recv_op])
>>> for req in reqs:
>>>     req.wait()
>>> recv_tensor
tensor([2, 3])     # Rank 0
tensor([0, 1])     # Rank 1

注意:当此API与NCCL PG后端一起使用时,用户必须通过torch.cuda.set_device设置当前GPU设备,否则会导致意外的挂起问题。

此外,如果此API是传入dist.P2POpgroup中的第一个集合通信调用,则该group的所有进程都必须参与此次API调用;否则行为将是未定义的。如果此API调用不是group中的第一个集合通信操作,则允许仅涉及group中部分进程的批量P2P操作。


class torch.distributed.P2POp(op, tensor, peer=None, group=None, tag=0, group_peer=None)

一个用于为batch_isend_irecv构建点对点操作的类。

该类构建P2P操作类型、通信缓冲区、对等节点秩、进程组和标签。此类的实例将被传递给batch_isend_irecv以进行点对点通信。

参数

  • op (Callable) – 用于向对等进程发送或接收数据的函数。

op的类型为torch.distributed.isendtorch.distributed.irecv

  • tensor ( Tensor ) – 要发送或接收的张量。
  • peer ( int , optional) – 目标或源秩。
  • group (ProcessGroup, optional) – 要操作的进程组。如果为None,将使用默认进程组。
  • tag ( int , optional) – 用于匹配发送与接收的标签。
  • group_peer ( int , optional) – 目标或源秩。

同步与异步集合操作

每个集合操作函数都支持以下两种操作模式,具体取决于传入的async_op标志设置:

同步操作 - 默认模式,当async_op设为False时生效。函数返回时,可以确保集合操作已执行完成。对于CUDA操作而言,由于CUDA操作本身是异步的,此时不能保证CUDA操作已完成。对于CPU集合操作,后续使用该操作输出的函数调用将按预期工作。对于CUDA集合操作,在同一个CUDA流中使用输出的函数调用将按预期工作。若在不同流中运行,用户需自行处理同步问题。有关CUDA语义(如流同步)的详细信息,请参阅CUDA语义。下方脚本展示了CPU与CUDA操作在这些语义上的差异示例。

异步操作 - 当async_op设为True时生效。集合操作函数会返回一个分布式请求对象。通常无需手动创建该对象,它保证支持以下方法:

  • is_completed() - 对于CPU集合操作,完成时返回True。对于CUDA操作,当操作成功加入CUDA流且输出可在默认流中使用而无需额外同步时返回True
  • wait() - 对于CPU集合操作,将阻塞进程直至操作完成。对于CUDA集合操作,将阻塞当前活跃的CUDA流直至操作完成(但不会阻塞CPU)
  • get_future() - 返回torch._C.Future对象。支持NCCL后端,也支持GLOO和MPI后端的大多数操作(点对点操作除外)
    注意:随着我们持续采用Future并合并API,get_future()调用可能会变得冗余

示例

以下代码可作为使用分布式集合操作时CUDA操作语义的参考,展示了在不同CUDA流中使用集合操作输出时需要显式同步的情况:

# Code runs on each rank.
dist.init_process_group("nccl", rank=rank, world_size=2)
output = torch.tensor([rank]).cuda(rank)
s = torch.cuda.Stream()
handle = dist.all_reduce(output, async_op=True)
# Wait ensures the operation is enqueued, but not necessarily complete.
handle.wait()
# Using result on non-default stream. with torch.cuda.stream(s):s.wait_stream(torch.cuda.default_stream())output.add_(100) if rank == 0:# if the explicit call to wait_stream was omitted, the output below will be     # non-deterministically 1 or 101, depending on whether the allreduce overwrote# the value after the add completed.print(output)

集合函数


torch.distributed.broadcast(tensor, src=None, group=None, async_op=False, group_src=None)

将张量广播到整个进程组。

所有参与集体通信的进程中,tensor 必须具有相同的元素数量。

参数说明

  • tensor ( Tensor ) - 如果当前进程是源进程(src),则作为待发送数据;否则作为接收数据的存储张量。
  • src ( int ) - 全局进程组中的源进程排名(不受group参数影响)。
  • group (ProcessGroup, 可选) - 操作的进程组。若为None,则使用默认进程组。
  • async_op ([bool], 可选) - 是否作为异步操作执行。
  • group_src ( int ) - 指定group内的源进程排名。必须且只能指定group_srcsrc中的一个。

返回值

  • 若async_op设为True,返回异步操作句柄。
  • 若非异步操作或不属于该进程组,返回None。

torch.distributed.broadcast_object_list(object_list, src=None, group=None, device=None, group_src=None)

object_list 中的可序列化对象广播到整个组。

类似于 broadcast(),但可以传入 Python 对象。

注意,object_list 中的所有对象必须可序列化才能被广播。

参数

  • object_list (List[Any]) – 要广播的输入对象列表。每个对象必须可序列化。只有 src 进程上的对象会被广播,但每个进程必须提供大小相同的列表。
  • src ( int ) – 广播 object_list 的源进程号。源进程号基于全局进程组(与 group 参数无关)。
  • group (Optional[ProcessGroup]) – (可选)要操作的进程组。如果为 None,则使用默认进程组。默认为 None
  • device (torch.device, optional) – 如果非 None,对象会被序列化并转换为张量,广播前移动到 device。默认为 None
  • group_src ( int )group 上的源进程号。不能同时指定 group_srcsrc

返回

None。如果当前进程属于该组,object_list 将包含从 src 进程广播的对象。

注意:对于基于 NCCL 的进程组,对象的内部张量表示必须在通信前移动到 GPU 设备。此时使用的设备由 torch.cuda.current_device() 给出,用户需确保通过 torch.cuda.set_device() 设置每个进程有独立的 GPU。

注意:此 API 与 broadcast() 略有不同,因为它不提供 async_op 句柄,因此是阻塞调用。

警告broadcast_object_list() 隐式使用 pickle 模块,已知其不安全。恶意构造的 pickle 数据可能在反序列化时执行任意代码。仅对可信数据调用此函数。

警告:使用 GPU 张量调用 broadcast_object_list() 支持不佳且效率低下,因为张量会被序列化导致 GPU-CPU 传输。建议改用 broadcast()

示例

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3、>>    objects = ["foo", 12, {1: 2}] # any picklable object
>>> else:
>>>     objects = [None, None, None]
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> dist.broadcast_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]

torch.distributed.all_reduce(tensor, op=<RedOpType.SUM: 0>, group=None, async_op=False)

以所有机器都能获取最终结果的方式对张量数据进行归约操作。

调用后,所有进程中的 tensor 将保持二进制级别的一致性。

支持复数张量。

参数

  • tensor (Tensor) - 集合操作的输入和输出张量。该函数会就地修改张量。
  • op (可选) - 从 torch.distributed.ReduceOp 枚举中选择的操作类型。指定用于逐元素归约的运算方式。
  • group (ProcessGroup, 可选) - 要操作的工作进程组。若为 None,则使用默认进程组。
  • async_op (bool, 可选) - 是否将此操作设为异步操作。

返回

  • 若 async_op 设为 True,返回异步操作句柄。
  • 若非异步操作或不属于该进程组,则返回 None。

示例:

>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor(
...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1

torch.distributed.reduce(tensor, dst=None, op=<RedOpType.SUM: 0>, group=None, async_op=False, group_dst=None)

在所有机器间对张量数据进行归约操作。

只有排名为 dst 的进程会接收到最终结果。

参数

  • tensor ( Tensor ) – 集合操作的输入和输出张量。该函数会就地修改数据。
  • dst ( int ) – 全局进程组中的目标排名(不受 group 参数影响)
  • op (可选) – 从 torch.distributed.ReduceOp 枚举中选择的值。指定用于逐元素归约的操作类型。
  • group (ProcessGroup, 可选) – 要操作的目标进程组。若为 None,则使用默认进程组。
  • async_op ([bool], 可选) – 是否将此操作设为异步操作
  • group_dst ( int ) – 在 group 上的目标排名。必须指定 group_dstdst 中的一个,但不能同时指定两者。

返回值:若 async_op 设为 True,则返回异步操作句柄。

若未设置 async_op 或不属于该进程组,则返回 None


torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)

从整个进程组中收集张量到列表中。

支持复杂且大小不一的张量。

参数

  • tensor_list (list[Tensor]) - 输出列表。该列表应包含正确尺寸的张量,用于集合通信的输出。支持大小不一的张量。
  • tensor (Tensor) - 从当前进程广播的张量。
  • group (ProcessGroup, 可选) - 要操作的进程组。如果为None,则使用默认进程组。
  • async_op ([bool], 可选) - 该操作是否应为异步操作

返回

如果async_op设置为True,则返回异步工作句柄。

如果不设置async_op或不属于该进程组,则返回None


示例:

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_list = [
...     torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
[tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1

>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [
...     torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
>>> tensor = torch.tensor(
...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
[tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1

torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False)

从所有进程收集张量并合并为一个输出张量。

此函数要求每个进程上的所有张量大小相同。

参数

  • output_tensor (Tensor) - 用于容纳来自所有进程张量元素的输出张量。其尺寸必须正确设置为以下形式之一:

(i) 沿主维度拼接所有输入张量;关于"拼接"的定义,请参阅 torch.cat()

(ii) 沿主维度堆叠所有输入张量;关于"堆叠"的定义,请参阅 torch.stack()

下方示例可以更清楚地说明支持的输出形式。

  • input_tensor (Tensor) - 从当前进程收集的输入张量。

all_gather API 不同,本 API 要求所有进程的输入张量必须具有相同大小。

  • group (ProcessGroup, 可选) - 要操作的工作进程组。如果为 None,则使用默认进程组。
  • async_op ([bool], 可选) - 是否将此操作设为异步操作

返回

如果 async_op 设为 True,则返回异步操作句柄。

如果不设 async_op 或不属于该进程组,则返回 None


示例:

>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> # Output in concatenation form
>>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
>>> dist.all_gather_into_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
>>> # Output in stack form
>>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
>>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
>>> tensor_out2
tensor([[1, 2], [3, 4]], device='cuda:0') # Rank 0
tensor([[1, 2], [3, 4]], device='cuda:1') # Rank 1

警告:Gloo 后端不支持此 API。


torch.distributed.all_gather_object(object_list, obj, group=None)

将整个组中的可pickle对象收集到一个列表中。

类似于 all_gather(),但可以传递Python对象。

注意:对象必须是可pickle的才能被收集。

参数

  • object_list (list[Any]) – 输出列表。其大小应正确设置为该集合操作的组大小,并将包含输出结果。
  • obj (Any) – 从当前进程广播的可pickle的Python对象。
  • group (ProcessGroup, 可选) – 要操作的工作进程组。如果为None,则使用默认进程组。默认为None

返回

无。如果调用rank属于该组,集合操作的输出将填充到输入的object_list中。如果调用rank不属于该组,传入的object_list将保持不变。

注意:请注意此API与 all_gather() 集合操作略有不同,因为它不提供async_op句柄,因此将是一个阻塞调用。

注意:对于基于NCCL的进程组,对象的内部张量表示必须在通信发生前移动到GPU设备。这种情况下,使用的设备由torch.cuda.current_device()给出,用户有责任通过torch.cuda.set_device()确保每个rank都有独立的GPU。

警告all_gather_object() 隐式使用pickle模块,已知该模块不安全。可能构造恶意的pickle数据,在反序列化时执行任意代码。仅对可信数据调用此函数。

警告:使用GPU张量调用 all_gather_object() 支持不佳且效率低下,因为张量需要被pickle会导致GPU-CPU传输。请考虑改用 all_gather()

示例

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3、>>gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]

torch.distributed.gather(tensor, gather_list=None, dst=None, group=None, async_op=False, group_dst=None)

将多个进程中的张量列表收集到单个进程中。

此函数要求每个进程中的所有张量大小必须相同。

参数

  • tensor ( Tensor ) – 输入张量。
  • gather_list (list[Tensor ], 可选) – 用于收集数据的适当大小且尺寸相同的张量列表(默认为None,必须在目标rank上指定)
  • dst ( int , 可选) – 全局进程组中的目标rank(不受group参数影响)。(如果dstgroup_dst均为None,则默认为全局rank 0)
  • group (ProcessGroup, 可选) – 要操作的工作进程组。如果为None,则使用默认进程组。
  • async_op ([bool], 可选) – 此操作是否应为异步操作
  • group_dst ( int , 可选)group中的目标rank。不允许同时指定dstgroup_dst

返回值:如果async_op设置为True,则返回异步工作句柄。

如果未设置async_op或不属于该进程组,则返回None

注意:gather_list中的所有张量必须具有相同的大小。

示例:


>>> # We have 2 process groups, 2 ranks.
>>> tensor_size = 2
>>> device = torch.device(f'cuda:{rank}')
>>> tensor = torch.ones(tensor_size, device=device) + rank
>>> if dist.get_rank() == 0:
>>>     gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
>>> else:
>>>     gather_list = None
>>> dist.gather(tensor, gather_list, dst=0)
>>> # Rank 0 gets gathered data.
>>> gather_list
[tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
None                                                                   # Rank 1

torch.distributed.gather_object(obj, object_gather_list=None, dst=None, group=None, group_dst=None)

从整个进程组中收集可序列化对象到单个进程。

功能类似于 gather(),但支持传递Python对象。注意:待收集的对象必须可序列化。

参数

  • obj (Any) – 输入对象,必须可序列化。
  • object_gather_list (list[Any]) – 输出列表。在目标dst进程上,该列表需预先分配为进程组大小的空间以存储结果。非目标进程上必须设为None(默认值为None)。
  • dst ( int , optional) – 全局进程组中的目标进程编号(不受group参数影响)。若dstgroup_dst均为None,则默认为全局0号进程。
  • group (Optional[ProcessGroup]) – 操作的目标进程组。若为None则使用默认进程组(默认值为None)。
  • group_dst ( int , optional) – 指定group参数对应进程组中的目标进程编号。不可同时指定dstgroup_dst

返回值

无。在目标dst进程上,object_gather_list将包含集合操作的结果。

注意

本API与常规gather操作略有不同:不提供async_op异步句柄,因此是阻塞调用。

注意

对于基于NCCL的进程组,对象内部的张量表示必须在通信前移至GPU设备。此时设备由torch.cuda.current_device()决定,用户需通过torch.cuda.set_device()确保每个进程独占GPU。

警告

gather_object()隐式使用pickle模块,该模块存在安全隐患。恶意构造的pickle数据可能在反序列化时执行任意代码。请仅对可信数据调用此函数。

警告

对GPU张量调用gather_object()支持不佳且效率低下,因为序列化会引发GPU-CPU传输。建议改用gather()


示例:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3、>>gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
...     gather_objects[dist.get_rank()], 
...     output if dist.get_rank() == 0 else None, 
...     dst=0
... )
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]

torch.distributed.scatter(tensor, scatter_list=None, src=None, group=None, async_op=False, group_src=None)

将一组张量分散到进程组中的所有进程。

每个进程将准确接收一个张量,并将其数据存储在 tensor 参数中。

支持复数张量。

参数

  • tensor ( Tensor ) – 输出张量。
  • scatter_list (list[Tensor ]) – 要分散的张量列表(默认为 None,必须在源 rank 上指定)
  • src ( int ) – 全局进程组中的源 rank(不受 group 参数影响)。

(如果 srcgroup_src 均为 None,则默认为全局 rank 0)

  • group (ProcessGroup, 可选) – 要操作的进程组。如果为 None,则使用默认进程组。
  • async_op ([bool], 可选) – 此操作是否应为异步操作
  • group_src ( int , 可选)group 中的源 rank。不能同时指定 srcgroup_src

返回

如果 async_op 设置为 True,则返回异步工作句柄。

如果不为 async_op 或不属于该组,则返回 None

注意:请注意,scatter_list 中的所有张量必须具有相同的大小。

示例

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> tensor_size = 2
>>> device = torch.device(f'cuda:{rank}')
>>> output_tensor = torch.zeros(tensor_size, device=device)
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 2、>>    # Only tensors, all of which must be the same size.
>>>     t_ones = torch.ones(tensor_size, device=device)
>>>     t_fives = torch.ones(tensor_size, device=device) * 5
>>>     scatter_list = [t_ones, t_fives]
>>> else:
>>>     scatter_list = None
>>> dist.scatter(output_tensor, scatter_list, src=0)
>>> # Rank i gets scatter_list[i].
>>> output_tensor
tensor([1., 1.], device='cuda:0') # Rank 0
tensor([5., 5.], device='cuda:1') # Rank 1

torch.distributed.scatter_object_list(scatter_object_output_list, scatter_object_input_list=None, src=None, group=None, group_src=None)

scatter_object_input_list 中的可序列化对象分发到整个组中。

类似于 scatter(),但可以传递 Python 对象。在每个 rank 上,分发的对象将作为 scatter_object_output_list 的第一个元素存储。注意,scatter_object_input_list 中的所有对象必须可序列化才能被分发。

参数

  • scatter_object_output_list (List[Any]) – 非空列表,其第一个元素将存储分发到当前 rank 的对象。
  • scatter_object_input_list (List[Any], optional) – 要分发的输入对象列表。每个对象必须可序列化。只有 src rank 上的对象会被分发,非 src rank 可以传入 None
  • src ( int ) – 分发 scatter_object_input_list 的源 rank。源 rank 基于全局进程组(与 group 参数无关)。(如果 srcgroup_src 均为 None,则默认为全局 rank 0)
  • group (Optional[ProcessGroup]) – (ProcessGroup,可选):要操作的进程组。如果为 None,则使用默认进程组。默认为 None
  • group_src ( int , optional)group 上的源 rank。不能同时指定 srcgroup_src

返回值

None。如果当前 rank 属于该组,scatter_object_output_list 的第一个元素将被设置为分发到该 rank 的对象。

注意:请注意此 API 与 scatter 集合操作略有不同,因为它不提供 async_op 句柄,因此是一个阻塞调用。

警告scatter_object_list() 隐式使用了 pickle 模块,已知该模块不安全。可能构造恶意的 pickle 数据,在反序列化时执行任意代码。请仅对可信数据调用此函数。

警告:使用 GPU 张量调用 scatter_object_list() 支持不佳且效率低下,因为张量需要序列化会导致 GPU-CPU 传输。请考虑改用 scatter()

示例

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3、>>    objects = ["foo", 12, {1: 2}] # any picklable object
>>> else:
>>>     # Can be any list on non-src ranks, elements are not used.
>>>     objects = [None, None, None]
>>> output_list = [None]
>>> dist.scatter_object_list(output_list, objects, src=0)
>>> # Rank i gets objects[i]. For example, on rank 2:
>>> output_list
[{1: 2}]

torch.distributed.reduce_scatter(output, input_list, op=<RedOpType.SUM: 0>, group=None, async_op=False)

将一组张量进行归约后分散到进程组中的所有进程。

参数

  • output ( Tensor ) – 输出张量。
  • input_list (list[Tensor ]) – 待归约和分散的张量列表。
  • op (可选) – 从 torch.distributed.ReduceOp 枚举中选择的值。指定用于逐元素归约的操作。
  • group (ProcessGroup, 可选) – 要操作的进程组。如果为 None,则使用默认进程组。
  • async_op ([bool], 可选) – 此操作是否应为异步操作。

返回值:如果 async_op 设为 True,则返回异步工作句柄。

如果不为异步操作或不属于该进程组,则返回 None。


torch.distributed.reduce_scatter_tensor(output, input, op=<RedOpType.SUM: 0>, group=None, async_op=False)

对张量进行归约操作后,将其分散到组内所有进程中。

参数

  • output (Tensor) - 输出张量。所有进程中的该张量应保持相同大小。
  • input (Tensor) - 待归约和分散的输入张量。其大小应为输出张量大小乘以进程组规模。输入张量可具有以下两种形状之一:

(i) 沿主维度拼接的输出张量序列,或

(ii) 沿主维度堆叠的输出张量序列。

关于"拼接"的定义,请参阅 torch.cat()

关于"堆叠"的定义,请参阅 torch.stack()

  • group (ProcessGroup, 可选) - 要操作的进程组。若为None,则使用默认进程组。
  • async_op (bool, 可选) - 是否将此操作设为异步操作。

返回

若 async_op 设为 True,返回异步工作句柄。

若未设置 async_op 或不属于该进程组,返回 None。


示例:

>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
>>> # Input in concatenation form
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
>>> tensor_in
tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
>>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1
>>> # Input in stack form
>>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
>>> tensor_in
tensor([[0, 1], [2, 3]], device='cuda:0') # Rank 0
tensor([[0, 1], [2, 3]], device='cuda:1') # Rank 1
>>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

警告:Gloo 后端不支持此 API。


torch.distributed.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False)

将输入张量分割后分散到组内所有进程中。

随后从组内所有进程接收到的张量会被拼接起来,作为单个输出张量返回。

支持复数张量。

参数

  • output ( Tensor ) – 收集拼接后的输出张量。
  • input ( Tensor ) – 待分散的输入张量。
  • output_split_sizes – (list[Int], 可选): 如果指定为None或空列表,则要求output张量的第0维必须能被world_size整除;否则指定第0维的输出分割尺寸。
  • input_split_sizes – (list[Int], 可选): 如果指定为None或空列表,则要求input张量的第0维必须能被world_size整除;否则指定第0维的输入分割尺寸。
  • group (ProcessGroup, 可选) – 要操作的工作进程组。如果为None,则使用默认进程组。
  • async_op ([bool], 可选) – 是否将此操作设为异步操作。

返回值:如果async_op设为True,则返回异步操作句柄。

如果不设async_op或不属于该进程组,则返回None。

警告:all_to_all_single是实验性功能,后续可能变更。

示例


>>> input = torch.arange(4) + rank * 4
>>> input
tensor([0, 1, 2, 3])     # Rank 0
tensor([4, 5, 6, 7])     # Rank 1
tensor([8, 9, 10, 11])   # Rank 2
tensor([12, 13, 14, 15]) # Rank 3
>>> output = torch.empty([4], dtype=torch.int64)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([0, 4, 8, 12])    # Rank 0
tensor([1, 5, 9, 13])    # Rank 1
tensor([2, 6, 10, 14])   # Rank 2
tensor([3, 7, 11, 15])   # Rank 3

>>> # Essentially, it is similar to following operation:
>>> scatter_list = list(input.chunk(world_size))
>>> gather_list = list(output.chunk(world_size))
>>> for i in range(world_size):
>>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)

>>> # Another example with uneven split
>>> input
tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
tensor([20, 21, 22, 23, 24])                                     # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
>>> input_splits
[2, 2, 1, 1]                                                     # Rank 0
[3, 2, 2, 2]                                                     # Rank 1
[2, 1, 1, 1]                                                     # Rank 2
[2, 2, 2, 1]                                                     # Rank 3
>>> output_splits
[2, 3, 2, 2]                                                     # Rank 0
[2, 2, 1, 2]                                                     # Rank 1
[1, 2, 1, 2]                                                     # Rank 2
[1, 2, 1, 1]                                                     # Rank 3
>>> output = ...
>>> dist.all_to_all_single(output, input, output_splits, input_splits)
>>> output
tensor([0, 1, 10, 11, 12, 20, 21, 30, 31])                     # Rank 0
tensor([2, 3, 13, 14, 22, 32, 33])                             # Rank 1
tensor([4, 15, 16, 23, 34, 35])                                 # Rank 2
tensor([5, 17, 18, 24, 36])                                     # Rank 3

>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor(
...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input
tensor([1+1j, 2+2j, 3+3j, 4+4j])                                # Rank 0
tensor([5+5j, 6+6j, 7+7j, 8+8j])                                # Rank 1
tensor([9+9j, 10+10j, 11+11j, 12+12j])                          # Rank 2
tensor([13+13j, 14+14j, 15+15j, 16+16j])                        # Rank 3
>>> output = torch.empty([4], dtype=torch.int64)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([1+1j, 5+5j, 9+9j, 13+13j])                              # Rank 0
tensor([2+2j, 6+6j, 10+10j, 14+14j])                            # Rank 1
tensor([3+3j, 7+7j, 11+11j, 15+15j])                            # Rank 2
tensor([4+4j, 8+8j, 12+12j, 16+16j])                            # Rank 3

torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False)

将输入张量列表分散到组内所有进程,并返回聚合后的输出张量列表。

支持复数张量。

参数

  • output_tensor_list (list[Tensor]) - 每个rank待聚合的张量列表。
  • input_tensor_list (list[Tensor]) - 每个rank待分散的张量列表。
  • group (ProcessGroup, 可选) - 操作的工作进程组。若为None,则使用默认进程组。
  • async_op (bool, 可选) - 是否将操作设为异步模式。

返回值

  • 若async_op设为True,返回异步操作句柄。
  • 若非异步模式或不属于该进程组,则返回None。

警告:all_to_all接口处于实验阶段,后续可能变更。


示例:

>>> input = torch.arange(4) + rank * 4
>>> input = list(input.chunk(4))
>>> input
[tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
[tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
[tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
[tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
[tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
[tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
[tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3

>>> # Essentially, it is similar to following operation:
>>> scatter_list = input
>>> gather_list = output
>>> for i in range(world_size):
>>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

>>> input
tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
tensor([20, 21, 22, 23, 24])                                     # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
>>> input_splits
[2, 2, 1, 1]                                                     # Rank 0
[3, 2, 2, 2]                                                     # Rank 1
[2, 1, 1, 1]                                                     # Rank 2
[2, 2, 2, 1]                                                     # Rank 3
>>> output_splits
[2, 3, 2, 2]                                                     # Rank 0
[2, 2, 1, 2]                                                     # Rank 1
[1, 2, 1, 2]                                                     # Rank 2
[1, 2, 1, 1]                                                     # Rank 3
>>> input = list(input.split(input_splits))
>>> input
[tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
[tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
[tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
[tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
>>> output = ...
>>> dist.all_to_all(output, input)
>>> output
[tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
[tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
[tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3

>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor(
...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input = list(input.chunk(4))
>>> input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]            # Rank 0
[tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]            # Rank 1
[tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]      # Rank 2
[tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]    # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]          # Rank 0
[tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]        # Rank 1
[tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]        # Rank 2
[tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]        # Rank 3

torch.distributed.barrier(group=None, async_op=False, device_ids=None)

同步所有进程。

如果 async_op 为 False,或者对 wait() 调用了异步工作句柄,该集合操作会阻塞进程,直到整个组进入此函数。

参数

  • group (ProcessGroup, 可选) – 要操作的进程组。如果为 None,则使用默认进程组。
  • async_op ([bool], 可选) – 该操作是否为异步操作
  • device_ids ([int], 可选) – 设备/GPU ID 列表。

返回值:如果 async_op 设为 True,返回异步工作句柄。

如果不为 async_op 或不属于该组,返回 None。

注意:ProcessGroupNCCL 现在会阻塞 CPU 线程,直到屏障集合操作完成。


torch.distributed.monitored_barrier(group=None, timeout=None, wait_all_ranks=False)

实现类似torch.distributed.barrier的进程同步功能,但支持可配置的超时机制。

该机制能够报告在指定超时时间内未能通过屏障的进程排名(ranks)。

具体而言:

  • 对于非0排名进程,会阻塞直至完成与rank 0的发送/接收操作
  • Rank 0进程会阻塞直至处理完所有其他进程的发送/接收操作,并上报超时未响应的进程排名
  • 注意:若任一进程未到达monitored_barrier(例如因挂起),所有其他进程都会在monitored_barrier处失败

这个集合操作会阻塞组内所有进程/排名,直到整个组成功退出该函数,因此非常适用于调试和同步场景。但需注意其性能开销,建议仅用于调试或需要主机端完全同步点的场景。调试时可在应用程序的集合调用前插入此屏障,用于检查是否存在进程不同步的情况。

注意:该集合操作仅支持GLOO后端。

参数说明

  • group (ProcessGroup, 可选) - 要操作的工作进程组。若为None则使用默认进程组
  • timeout ([datetime.timedelta, 可选) - monitored_barrier的超时时间。若为None则使用默认进程组超时设置
  • wait_all_ranks ([bool], 可选) - 是否收集所有失败进程排名。默认为False,此时rank 0上的monitored_barrier会在遇到第一个失败排名时立即抛出异常以实现快速失败。若设为True则会收集所有失败排名并抛出包含全部失败信息的错误

返回值

None

使用示例


>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() != 1:
>>>     dist.monitored_barrier() # Raises exception indicating that >># rank 1 did not call into monitored_barrier.
>>> # Example with wait_all_ranks=True
>>> if dist.get_rank() == 0:
>>>     dist.monitored_barrier(wait_all_ranks=True) # Raises exception
>>> # indicating that ranks 1, 2, 
... world_size - 1 did not call into
>>> # monitored_barrier.

class torch.distributed.Work 

Work对象代表PyTorch分布式包中一个待处理的异步操作句柄。它由非阻塞的集合操作返回,例如dist.all_reduce(tensor, async_op=True)

boxed(self: torch._C._distributed_c10d.Work) → object

exception(self: torch._C._distributed_c10d.Work) → std::__exception_ptr::exception_ptr

get_future(self: torch._C._distributed_c10d.Work) → torch.Future

返回值:一个与Work完成相关联的torch.futures.Future对象。例如,可以通过fut = process_group.allreduce(tensors).get_future()获取future对象。

示例:下面是一个简单的allreduce DDP通信钩子示例,它使用get_future API来检索与allreduce完成相关联的Future。


>>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket): -torch.futures.Future>>>     group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD>>>     tensor = bucket.buffer().div_(group_to_use.size())>>>     return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()>>> ddp_model.register_comm_hook(state=None, hook=allreduce)

警告:get_future API 支持 NCCL 后端,部分支持 GLOO 和 MPI 后端(不支持点对点操作如 send/recv),并将返回一个 torch.futures.Future

在上述示例中,allreduce 操作将通过 NCCL 后端在 GPU 上执行。fut.wait() 会在 NCCL 流与 PyTorch 当前设备流同步后返回,以确保支持异步 CUDA 执行,而无需等待整个 GPU 操作完成。请注意,CUDAFuture 不支持 TORCH_NCCL_BLOCKING_WAIT 标志或 NCCL 的 barrier() 功能。

此外,若通过 fut.then() 添加了回调函数,该回调将等待 WorkNCCL 的 NCCL 流与 ProcessGroupNCCL 的专用回调流同步,并在回调流上执行后立即触发回调。fut.then() 会返回另一个 CUDAFuture,其中包含回调函数的返回值以及记录回调流的 CUDAEvent

1、对于 CPU 任务,fut.done() 在任务完成且 value() 张量就绪时返回 true。
2、对于 GPU 任务,fut.done() 仅在操作已加入队列时返回 true。
3、对于 CPU-GPU 混合任务(例如通过 GLOO 发送 GPU 张量),fut.done() 在张量到达目标节点时返回 true,但 GPU 上的同步可能尚未完成(与纯 GPU 任务类似)。


get_future_result(self: torch._C._distributed_c10d.Work) → torch.Future

返回

一个torch.futures.Future类型的对象,其整数值对应WorkResult枚举类型

例如,可以通过fut = process_group.allreduce(tensor).get_future_result()获取future对象。


示例:用户可以使用fut.wait()阻塞等待工作完成,并通过fut.value()获取WorkResult。

此外,用户还可以使用fut.then(call_back_func)注册回调函数,

该函数会在工作完成时被调用,且不会阻塞当前线程。

警告:get_future_result API仅支持NCCL


is_completed(self: torch._C._distributed_c10d.Work)bool

is_success(self: torch._C._distributed_c10d.Work)bool  

result(self: torch._C._distributed_c10d.Work)list [torch.Tensor]

获取工作对象的结果,返回一个包含torch.Tensor的列表


source_rank(self: torch._C._distributed_c10d.Work)int

获取发送该工作对象的源进程排名,返回一个整数值


synchronize(self: torch._C._distributed_c10d.Work)None  

static unbox(arg0:  object ) → torch._C._distributed_c10d.Work 

wait(self: torch._C._distributed_c10d.Work, timeout: [datetime.timedelta = datetime.timedelta(0))bool  

返回值 : true/false。

示例::


try:work.wait(timeout)
except:# some handling

警告:通常情况下,用户无需设置超时参数。

调用 wait() 等同于调用 synchronize():

会使当前流阻塞直至 NCCL 工作完成。

但如果设置了超时参数,则会阻塞 CPU 线程直至 NCCL 工作完成或超时。若发生超时,将抛出异常。


class torch.distributed.ReduceOp 

一个枚举类,用于表示可用的归约操作:SUM(求和)、PRODUCT(乘积)、MIN(最小值)、MAX(最大值)、BAND(按位与)、BOR(按位或)、BXOR(按位异或)以及PREMUL_SUM(预乘求和)。

注意事项:

  • 当使用NCCL后端时,BANDBORBXOR归约操作不可用。
  • AVG(平均值)会在跨节点求和前将数值除以全局进程数。该操作仅支持NCCL后端,且要求NCCL版本为2.10及以上。
  • PREMUL_SUM会在归约前将输入张量乘以指定的标量。该操作仅支持NCCL后端,且要求NCCL版本为2.11及以上。用户应使用torch.distributed._make_nccl_premul_sum来调用。
  • 复数张量不支持MAXMINPRODUCT操作。

使用方式:

  • 可通过属性访问枚举值,例如ReduceOp.SUM
  • 用于指定集合通信的归约策略,例如reduce()

限制说明:

  • 本类不支持__members__属性

class torch.distributed.reduce_op

已弃用的枚举式类,用于定义归约操作:SUM(求和)、PRODUCT(乘积)、MIN(最小值)和MAX(最大值)。

建议改用 ReduceOp 类。


分布式键值存储

分布式包内置了一个分布式键值存储,可用于在进程组之间共享信息,也可用于初始化分布式包(通过显式创建存储作为指定 init_method 的替代方案)。键值存储有三种选择:TCPStoreFileStoreHashStore


class torch.distributed.Store 
Base class for all store implementations, such as the 3 provided by PyTorch
distributed: ([`TCPStore`](https://pytorch.org/docs/stable/data.html#torch.distributed.TCPStore "torch.distributed.TCPStore"), [`FileStore`](https://pytorch.org/docs/stable/data.html#torch.distributed.FileStore "torch.distributed.FileStore"), and [`HashStore`](https://pytorch.org/docs/stable/data.html#torch.distributed.HashStore "torch.distributed.HashStore")).__init__(self: torch._C._distributed_c10d.Store)None  

add(self: torch._C._distributed_c10d.Store, arg0:  str , arg1:  int )int

首次对某个 key 调用 add 方法时,会在存储中创建一个与该 key 关联的计数器,并初始化为 amount 值。后续对相同 key 调用 add 方法时,计数器会按指定的 amount 值递增。

若调用 add() 时指定的 key 已被 set() 方法设置过,则会抛出异常。

参数

  • key (str) – 存储中待递增计数器的键名
  • amount ( int ) – 计数器递增的数值量

示例:

>>> import torch.distributed as dist>>> from datetime import timedelta>>> # 以TCPStore为例,其他存储类型也可使用
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))>>> store.add("first_key", 1)>>> store.add("first_key", 6)>>> # 应返回7
>>> store.get("first_key")

append(self: torch._C._distributed_c10d.Store, arg0:  str , arg1:  str )None  

根据提供的 keyvalue 将键值对追加到存储中。如果存储中不存在该 key,则会自动创建。

参数

  • key (str) – 要追加到存储中的键名。
  • value (str) – 与 key 关联并添加到存储中的值。

示例:

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.append("first_key", "po")
>>> store.append("first_key", "tato")
>>> # Should return "potato"
>>> store.get("first_key")

check(self: torch._C._distributed_c10d.Store, arg0: list[str])bool

检查给定keys列表是否在存储中有值的调用。该调用在正常情况下会立即返回,但仍可能遇到某些边缘死锁情况,例如在TCPStore已被销毁后调用检查。

调用check()时传入需要检查是否存在于存储中的键列表。

参数

  • keys (list[str]) – 需要查询是否存在于存储中的键列表。

示例:


>>> import torch.distributed as dist>>> from datetime import timedelta>>> # 以TCPStore为例,其他存储类型也可使用
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))>>> store.add("first_key", 1)>>> # 应返回7
>>> store.check(["first_key"])

compare_set(self: torch._C._distributed_c10d.Store, arg0:  str , arg1:  str , arg2:  str ) → bytes  

根据提供的 key 将键值对插入存储,并在插入前对 expected_valuedesired_value 进行比较。只有当该 key 对应的 expected_value 已存在于存储中,或 expected_value 为空字符串时,才会设置 desired_value

参数

  • key (str) – 需要在存储中检查的键名。
  • expected_value (str) – 插入前需检查的、与 key 关联的预期值。
  • desired_value (str) – 需要添加到存储中、与 key 关联的目标值。

示例:

>>> import torch.distributed as dist>>> from datetime import timedelta>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))>>> store.set("key", "first_value")>>> store.compare_set("key", "first_value", "second_value")>>> # 应返回 "second_value">>> store.get("key")


delete_key(self: torch._C._distributed_c10d.Store, arg0:  str ) → bool  

从存储中删除与key关联的键值对。如果键成功删除则返回true,否则返回false。

警告:delete_key API仅支持TCPStoreHashStore。在FileStore上使用此API会引发异常。

参数

  • key (str) - 要从存储中删除的键

返回值:如果key被删除则返回True,否则返回False。


示例:

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, HashStore can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key")
>>> # This should return true
>>> store.delete_key("first_key")
>>> # This should return false
>>> store.delete_key("bad_key")

get(self: torch._C._distributed_c10d.Store, arg0:  str )bytes

从存储中获取与给定key关联的值。如果key不存在于存储中,该函数将等待初始化存储时定义的timeout时长,然后抛出异常。


参数

  • key (str) – 函数将返回与此键关联的值。

返回值:如果key存在于存储中,则返回与之关联的值。


示例:


>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return "first_value"
>>> store.get("first_key")

has_extended_api(self: torch._C._distributed_c10d.Store)bool 

如果存储支持扩展操作,则返回 true。

multi_get(self: torch._C._distributed_c10d.Store, arg0: list [str ]) → list [bytes ]

获取 keys 中的所有值。如果 keys 中的任意键不存在于存储中,该函数将等待 timeout

参数

  • keys (List[str]) – 要从存储中获取的键列表。

示例:

>>> import torch.distributed as dist>>> from datetime import timedelta>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))>>> store.set("first_key", "po")>>> store.set("second_key", "tato")>>> # 应返回 [b"po", b"tato"]>>> store.multi_get(["first_key", "second_key"])

multi_set(self: torch._C._distributed_c10d.Store, arg0:  list  [str ], arg1:  list  [str ])None  

根据提供的 keysvalues 向存储中插入一个键值对列表

参数

  • keys (List[str]) – 要插入的键列表
  • values (List[str]) – 要插入的值列表

示例:

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.multi_set(["first_key", "second_key"], ["po", "tato"])
>>> # Should return b"po"
>>> store.get("first_key")

num_keys(self: torch._C._distributed_c10d.Store)int 

返回存储中设置的键数量。需要注意的是,这个数字通常会比通过set()add()方法添加的键数量多1,因为其中一个键用于协调所有使用该存储的工作进程。

警告:当与TCPStore一起使用时,num_keys返回的是写入底层文件的键数量。如果存储被销毁后,另一个存储使用同一文件创建,原有的键仍会被保留。

返回值:存储中当前存在的键数量。


示例:

>>> import torch.distributed as dist>>> from datetime import timedelta>>> # 以TCPStore为例,也可以使用其他存储类型
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))>>> store.set("first_key", "first_value")>>> # 这里应该返回2
>>> store.num_keys()


set(self: torch._C._distributed_c10d.Store, arg0:  str , arg1:  str ) → None  

根据提供的 keyvalue 将键值对插入存储中。如果 key 已存在于存储中,则会用新提供的 value 覆盖旧值。

参数

  • key (str) – 要添加到存储中的键。
  • value (str) – 与 key 关联并要添加到存储中的值。

示例:


>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return "first_value"
>>> store.get("first_key")

set_timeout(self: torch._C._distributed_c10d.Store, arg0: [datetime.timedelta)None 

设置存储的默认超时时间。该超时时间会在初始化期间以及在 wait()get() 方法中使用。

参数

  • timeout (timedelta) – 要设置到存储中的超时时间。

示例:

>>> import torch.distributed as dist>>> from datetime import timedelta>>> # 以TCPStore为例,也可以使用其他存储类型
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))>>> store.set_timeout(timedelta(seconds=10))>>> # 10秒后将抛出异常
>>> store.wait(["bad_key"])

property timeout

获取存储的超时设置。

wait(*args, **kwargs)

这是一个重载函数。

1、wait(self: torch._C._distributed_c10d.Store, arg0: list[str]) -None

等待keys列表中的每个键被添加到存储中。如果在timeout(存储初始化时设置)之前未设置所有键,则wait将抛出异常。


参数

  • keys (list) – 需要等待的键列表,直到它们在存储中被设置。

示例:


>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 30 seconds
>>> store.wait(["bad_key"])

2、wait(self: torch._C._distributed_c10d.Store, arg0: list[str], arg1: datetime.timedelta) -None

等待keys中的每个键被添加到存储中,如果在指定的timeout时间内这些键未被设置,则抛出异常。


参数说明

  • keys (list) – 需要等待其被设置到存储中的键列表。
  • timeout (timedelta) – 在抛出异常前等待键被添加的最长时间。

使用示例:


>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"], timedelta(seconds=10))

class torch.distributed.TCPStore 

基于TCP协议的分布式键值存储实现。服务器端存储数据,客户端存储可以通过TCP连接到服务器存储,并执行诸如set()插入键值对、get()获取键值对等操作。必须始终初始化一个服务器存储,因为客户端存储会等待服务器建立连接。

参数

  • host_name (str) – 服务器存储应运行的主机名或IP地址。
  • port (int) – 服务器存储监听传入请求的端口号。
  • world_size (int, 可选) – 存储用户总数(客户端数量 + 1个服务器)。默认为None(None表示存储用户数量不固定)。
  • is_master ([bool], 可选) – 初始化服务器存储时为True,客户端存储时为False。默认为False。
  • timeout (timedelta, 可选) – 存储初始化及get()wait()等方法使用的超时时间。默认为timedelta(seconds=300)。
  • wait_for_workers ([bool], 可选) – 是否等待所有工作节点与服务器存储建立连接。仅当world_size为固定值时适用。默认为True。
  • multi_tenant ([bool], 可选) – 若为True,当前进程中具有相同host/port的所有TCPStore实例将共享同一个底层TCPServer。默认为False。
  • master_listen_fd (int, 可选) – 若指定,底层TCPServer将监听此文件描述符(必须为已绑定到port的套接字)。适用于避免某些场景下的端口分配竞争。默认为None(表示服务器创建新套接字并尝试绑定到port)。
  • use_libuv ([bool], 可选) – 若为True,使用libuv作为TCPServer后端。默认为True。

示例:

>>> import torch.distributed as dist>>> from datetime import timedelta>>> # 在进程1(服务端)运行
>>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))>>> # 在进程2(客户端)运行
>>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)>>> # 初始化后,客户端或服务端均可使用存储方法
>>> server_store.set("first_key", "first_value")>>> client_store.get("first_key")


__init__(self: [torch._C._distributed_c10d.TCPStore](https://pytorch.org/docs/stable/data.html#torch.distributed.TCPStore "torch._C._distributed_c10d.TCPStore"), host_name:  str , port:  int , world_size: Optional[int ] = None, is_master:  bool  = False, timeout: [datetime.timedelta = datetime.timedelta(seconds=300), wait_for_workers:  bool  = True, multi_tenant:  bool  = False, master_listen_fd: Optional[int ] = None, use_libuv:  bool  = True) → None  

创建一个新的 TCPStore。


property host 

获取存储服务监听请求的主机名。


property libuvBackend

返回 True 表示当前正在使用 libuv 后端。


property port 

获取存储服务监听请求的端口号。


class torch.distributed.HashStore

一个基于底层哈希映射的线程安全存储实现。该存储可以在同一进程内使用(例如被其他线程使用),但不能跨进程使用。

示例:


>>> import torch.distributed as dist
>>> store = dist.HashStore()
>>> # store can be used from other threads
>>> # Use any of the store methods after initialization
>>> store.set("first_key", "first_value")

__init__(self: [torch._C._distributed_c10d.HashStore](https://pytorch.org/docs/stable/data.html#torch.distributed.HashStore "torch._C._distributed_c10d.HashStore")) → None

创建一个新的 HashStore。


class torch.distributed.FileStore

一个使用文件存储底层键值对的存储实现。


参数

  • file_name (str) – 用于存储键值对的文件路径
  • world_size ( int , 可选) – 使用该存储的进程总数。默认为-1(负值表示存储用户数量不固定)。

示例:


>>> import torch.distributed as dist
>>> store1 = dist.FileStore("/tmp/filestore", 2)
>>> store2 = dist.FileStore("/tmp/filestore", 2)
>>> # Use any of the store methods from either the client or server after initialization
>>> store1.set("first_key", "first_value")
>>> store2.get("first_key")

__init__(self: torch._C._distributed_c10d.FileStore, file_name: str, world_size: int = -1)None

创建一个新的 FileStore。


property path

获取FileStore用于存储键值对的文件路径。


class torch.distributed.PrefixStore 

对三种键值存储(TCPStoreFileStoreHashStore)的封装器,会在每个存入存储的键前添加前缀。

参数

  • prefix (str) - 在键存入存储前添加的前缀字符串。
  • store (torch.distributed.store) - 作为底层键值存储的存储对象。

__init__(self: torch._C._distributed_c10d.PrefixStore, prefix:  str , store: torch._C._distributed_c10d.Store)None  

创建一个新的 PrefixStore。


property underlying_store 

获取 PrefixStore 所封装的基础存储对象。


分析集体通信性能

请注意,您可以使用 torch.profiler(推荐使用,仅1.8.1版本后可用)或 torch.autograd.profiler 来分析本文提到的集体通信和点对点通信API。所有开箱即用的后端(glooncclmpi)都支持性能分析,集体通信的使用情况将在分析输出/跟踪中按预期呈现。分析代码的方式与常规的torch运算符完全相同:

import torchimport torch.distributed as dist 
with torch.profiler():tensor = torch.randn(20, 10)dist.all_reduce(tensor)

请参阅 性能分析器文档 以获取性能分析器功能的完整概述。


多GPU集合函数

警告:多GPU函数(指每个CPU线程对应多个GPU)已被弃用。目前,PyTorch分布式推荐采用每个线程对应一个设备的编程模型,本文档中的API即体现了这一模式。如果您是后端开发者且需要支持每个线程管理多个设备,请联系PyTorch分布式维护团队。


第三方后端

除了内置的 GLOO/MPI/NCCL 后端外,PyTorch 分布式模块通过运行时注册机制支持第三方后端。关于如何通过 C++ 扩展开发第三方后端的参考文档,请查阅 教程 - 自定义 C++ 和 CUDA 扩展 以及 test/cpp_extensions/cpp_c10d_extension.cpp。第三方后端的功能由其自身实现决定。

新后端需要继承自 c10d::ProcessGroup,并在导入时通过 torch.distributed.Backend.register_backend() 注册后端名称和实例化接口。

当手动导入该后端并通过指定后端名称调用 torch.distributed.init_process_group() 时,torch.distributed 包将运行在新的后端上。


警告:第三方后端支持目前处于实验阶段,后续可能发生变更。


启动工具

torch.distributed 包还在 torch.distributed.launch 中提供了一个启动工具。这个辅助工具可用于在每个节点上启动多个进程进行分布式训练。

模块 torch.distributed.launch

torch.distributed.launch 是一个模块,可在每个训练节点上生成多个分布式训练进程。


警告:该模块将被 torchrun 取代。

该工具可用于单节点分布式训练,其中每个节点会生成一个或多个进程。该工具既可用于 CPU 训练,也可用于 GPU 训练。如果用于 GPU 训练,每个分布式进程将在单个 GPU 上运行。这可以显著提升单节点训练性能。它也可用于多节点分布式训练,通过在每个节点上生成多个进程,同样显著提升多节点分布式训练性能。这对于具有多个支持直接 GPU 的 Infiniband 接口的系统尤其有益,因为所有这些接口都可以用于聚合通信带宽。

无论是单节点分布式训练还是多节点分布式训练,该工具都会在每个节点上启动指定数量的进程(--nproc-per-node)。如果用于 GPU 训练,这个数字需要小于或等于当前系统上的 GPU 数量(nproc_per_node),并且每个进程将在 GPU 0 到 GPU (nproc_per_node - 1) 上运行。

如何使用该模块:

1、单节点多进程分布式训练


python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVEYOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all otherarguments of your training script)

2、多节点多进程分布式训练:(例如两个节点)

节点1:(IP: 192.168.1.1,空闲端口:1234)


python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE--nnodes=2 --node-rank=0 --master-addr="192.168.1.1"--master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3and all other arguments of your training script)

Node 2:



python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE--nnodes=2 --node-rank=1 --master-addr="192.168.1.1"--master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3and all other arguments of your training script)

3、要查看该模块提供的可选参数:

python -m torch.distributed.launch --help

重要注意事项:

1、当前该工具及多进程分布式(单节点或多节点)GPU训练仅在NCCL分布式后端下才能实现最佳性能。因此,推荐在GPU训练中使用NCCL后端。

2、在训练程序中,必须解析命令行参数:

--local-rank=LOCAL_PROCESS_RANK(该参数将由本模块提供)。

若训练程序使用GPU,需确保代码仅在LOCAL_PROCESS_RANK对应的GPU设备上运行。可通过以下方式实现:

解析local_rank参数


>>> import argparse>>> parser = argparse.ArgumentParser()>>> parser.add_argument("--local-rank", "--local_rank", type=int)>>> args = parser.parse_args()

将您的设备设置为本地等级,可通过以下方式实现:

>>> torch.cuda.set_device(args.local_rank)  # 在代码运行前执行此操作

or


>>> with torch.cuda.device(args.local_rank):
>>>    # 在此处运行你的代码
>>>    ...

版本 2.0.0 变更:启动器会向您的脚本传递 --local-rank=<rank> 参数。

从 PyTorch 2.0.0 开始,推荐使用带连字符的 --local-rank 而非之前使用的带下划线形式 --local_rank

为了保持向后兼容性,用户可能需要在参数解析代码中同时处理这两种情况。这意味着在参数解析器中需要同时包含 "--local-rank""--local_rank"。如果仅提供 "--local_rank",启动器会报错:“error: unrecognized arguments: –local-rank=”。对于仅支持 PyTorch 2.0.0+ 的训练代码,包含 "--local-rank" 应该就足够了。

3、在您的训练程序中,应当在开始时调用以下函数来启动分布式后端。强烈建议使用 init_method=env://。其他初始化方法(如 tcp://)可能有效,但 env:// 是本模块官方支持的方式。


>>> torch.distributed.init_process_group(backend='YOUR BACKEND', init_method='env://')

在训练程序中,您可以选择使用常规的分布式函数,也可以使用 torch.nn.parallel.DistributedDataParallel() 模块。如果您的训练程序使用 GPU 进行训练,并且希望使用 torch.nn.parallel.DistributedDataParallel() 模块,以下是配置方法。


>>> model = torch.nn.parallel.DistributedDataParallel(model, >>                                                  device_ids=[args.local_rank], >>                                                  output_device=args.local_rank)

请确保将 device_ids 参数设置为代码将操作的唯一 GPU 设备 ID。这通常是进程的本地排名(local rank)。换句话说,要使用此工具,device_ids 需设为 [args.local_rank],且 output_device 需设为 args.local_rank

5、另一种通过环境变量 LOCAL_RANK 向子进程传递 local_rank 的方法:当使用 --use-env=True 启动脚本时,此功能会自动启用。你必须修改上述子进程示例,将 args.local_rank 替换为 os.environ['LOCAL_RANK'];若指定该标志,启动器将不会传递 --local-rank 参数。

警告:local_rank 并非全局唯一,它仅在单台机器的进程内唯一。因此,切勿用它来决定是否执行诸如写入网络文件系统等操作。若未正确处理,可能导致问题,具体案例可参考 https://github.com/pytorch/pytorch/issues/12042。


生成进程工具

多进程包 - torch.multiprocessing 提供了 torch.multiprocessing.spawn() 中的 spawn 函数。这个辅助函数可用于生成多个进程,其工作原理是传入目标执行函数,然后创建N个进程来运行该函数。该工具也可用于多进程分布式训练。

具体用法示例可参考 PyTorch示例 - ImageNet实现

注意:此功能需要Python 3.4或更高版本。


调试 torch.distributed 应用程序

由于难以理解的挂起、崩溃或跨进程的不一致行为,调试分布式应用程序可能具有挑战性。torch.distributed 提供了一套工具,以自助方式帮助调试训练应用程序:

Python 断点调试

在分布式环境中使用Python调试器极为便利,但由于开箱即用功能不足,许多人完全未使用它。PyTorch提供了一个定制化的pdb封装器,可简化这一流程。


torch.distributed.breakpoint 使该过程变得简单。其内部通过两种方式定制pdb的断点行为,其余功能与常规pdb一致:
1、仅在被用户指定的特定rank上附加调试器
2、通过调用torch.distributed.barrier()确保其他所有rank暂停运行,该屏障会在被调试rank发出继续指令后解除
3、将子进程的标准输入重定向至您的终端

使用时,只需在所有rank上调用torch.distributed.breakpoint(rank),并确保各rank传入相同的rank值即可。


监控式屏障

从 v1.10 版本开始,torch.distributed.monitored_barrier() 作为 torch.distributed.barrier() 的替代方案存在。当发生崩溃时(即并非所有 rank 在指定超时时间内调用 torch.distributed.monitored_barrier()),该函数会提供有关可能故障 rank 的有用信息。torch.distributed.monitored_barrier() 通过类似确认机制的 send/recv 通信原语在主机端实现屏障功能,使得 rank 0 能够报告哪些 rank 未能及时确认屏障。

例如,考虑以下场景:rank 1 未能调用 torch.distributed.monitored_barrier()(实际中可能由于应用程序错误或前一个集合操作挂起导致):

import osfrom datetime import timedeltaimport torchimport torch.distributed as distimport torch.multiprocessing as mpdef worker(rank):dist.init_process_group("nccl", rank=rank, world_size=2)

监控屏障需要 gloo 进程组执行主机端同步


group_gloo = dist.new_group(backend="gloo")if rank not in [1]:dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2))if __name__ == "__main__":os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "29501"mp.spawn(worker, nprocs=2, args=())

在 rank 0 上会产生以下错误信息,使用户能够判断哪些 rank 可能出现故障并进行进一步排查:

RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 msOriginal exception:
[gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594

说明:

1、保留了代码块格式和所有技术术语(如RuntimeErrorRankmonitoredBarrier

2、将被动语态"Connection closed by peer"转换为主动语态"对等方关闭了连接"

3、保持了IP地址和端口号的原始格式

4、错误信息路径[gloo/transport/tcp/pair.cc:598]保持原样

5、时间单位"ms"转换为中文习惯的"毫秒"


TORCH_DISTRIBUTED_DEBUG

当设置 TORCH_CPP_LOG_LEVEL=INFO 时,环境变量 TORCH_DISTRIBUTED_DEBUG 可用于触发额外的有用日志记录和集体同步检查,以确保所有进程能正确同步。根据所需的调试级别,TORCH_DISTRIBUTED_DEBUG 可设置为 OFF(默认)、INFODETAIL。请注意,最详细的选项 DETAIL 可能会影响应用程序性能,因此应仅在调试问题时使用。

设置 TORCH_DISTRIBUTED_DEBUG=INFO 会在初始化使用 torch.nn.parallel.DistributedDataParallel() 训练的模型时生成额外的调试日志;而设置 TORCH_DISTRIBUTED_DEBUG=DETAIL 还会在选定的迭代次数中记录运行时性能统计信息。这些运行时统计信息包括前向传播时间、反向传播时间、梯度通信时间等数据。例如,给定以下应用程序:

import osimport torch
import torch.distributed as dist
import torch.multiprocessing as mpclass TwoLinLayerNet(torch.nn.Module):def __init__(self):super().__init__()self.a = torch.nn.Linear(10, 10, bias=False)self.b = torch.nn.Linear(10, 1, bias=False)def forward(self, x):a = self.a(x)b = self.b(x)return (a, b)def worker(rank):dist.init_process_group("nccl", rank=rank, world_size=2)torch.cuda.set_device(rank)print("init model")model = TwoLinLayerNet().cuda()print("init ddp")ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])inp = torch.randn(10, 10).cuda()print("train")for _ in range(20):output = ddp_model(inp)loss = output[0] + output[1]loss.sum().backward()if __name__ == "__main__":os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "29501"os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging.mp.spawn(worker, nprocs=2, args=())

初始化时会渲染以下日志:

I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with:
broadcast_buffers: 1
bucket_cap_bytes: 26214400
find_unused_parameters: 0
gradient_as_bucket_view: 0
is_multi_device_module: 0
iteration: 0
num_parameter_tensors: 2
output_device: 0
rank: 0
total_parameter_size_bytes: 440
world_size: 2
backend_name: nccl
bucket_sizes: 440
cuda_visible_devices: N/A
device_ids: 0
dtypes: float
master_addr: localhost
master_port: 29501
module_name: TwoLinLayerNet
nccl_async_error_handling: N/A
nccl_blocking_wait: N/A
nccl_debug: WARN
nccl_ib_timeout: N/A
nccl_nthreads: N/A
nccl_socket_ifname: N/A
torch_distributed_debug: INFO

运行时(当设置 TORCH_DISTRIBUTED_DEBUG=DETAIL 时)会显示以下日志:

I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0Avg forward compute time: 40838608Avg backward compute time: 5983335
Avg backward comm. time: 4326421Avg backward comm/comp overlap time: 4207652
I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0Avg forward compute time: 42850427Avg backward compute time: 3885553
Avg backward comm. time: 2357981Avg backward comm/comp overlap time: 2234674

此外,TORCH_DISTRIBUTED_DEBUG=INFO 增强了 torch.nn.parallel.DistributedDataParallel() 中因模型存在未使用参数导致的崩溃日志记录。当前,如果前向传播中存在可能未被使用的参数,必须在初始化 torch.nn.parallel.DistributedDataParallel() 时传入 find_unused_parameters=True。从 v1.10 开始,由于 torch.nn.parallel.DistributedDataParallel() 不支持反向传播中存在未使用参数,所有模型输出都必须参与损失计算。这些限制对大型模型尤其具有挑战性。

因此当发生错误崩溃时,torch.nn.parallel.DistributedDataParallel() 会记录所有未被使用参数的完全限定名称。例如在上述应用中,如果将损失计算改为 loss = output[1],那么 TwoLinLayerNet.a 在反向传播中不会接收梯度,从而导致 DDP 失败。崩溃时,系统会向用户提供关于未使用参数的信息——这对于大型模型而言可能难以手动定位。


RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passingthe keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return va
lue of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 0: a.weight
Parameter indices which did not receive grad for rank 0: 0

设置 TORCH_DISTRIBUTED_DEBUG=DETAIL 会触发对用户发起的每个集体调用(无论是直接调用还是间接调用,例如 DDP 的 allreduce)进行额外的同步性和一致性检查。具体实现方式是创建一个包装器进程组,该包装器会包裹所有通过 torch.distributed.init_process_group()torch.distributed.new_group() API 返回的进程组。因此,这些 API 将返回一个包装器进程组,其使用方式与常规进程组完全相同,但在将集体操作分发给底层进程组之前会执行一致性检查。

目前,这些检查包括调用 torch.distributed.monitored_barrier(),该操作会确保所有节点完成未完成的集体调用,并报告卡住的节点。接着,系统会通过验证所有集体函数是否匹配且使用一致的张量形状来检查集体操作本身的一致性。如果不符合条件,应用程序崩溃时会提供包含详细错误信息的报告,而不是直接挂起或返回无意义的错误消息。例如,考虑以下函数中传入 torch.distributed.all_reduce() 的张量形状不匹配的情况:

import torch
import torch.distributed as dist
import torch.multiprocessing as mpdef worker(rank):dist.init_process_group("nccl", rank=rank, world_size=2)torch.cuda.set_device(rank)tensor = torch.randn(10 if rank == 0 else 20).cuda()dist.all_reduce(tensor)torch.cuda.synchronize(device=rank)if __name__ == "__main__":os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "29501"os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"mp.spawn(worker, nprocs=2, args=())

使用NCCL后端时,这类应用很可能会导致程序挂起,在复杂场景下难以定位根本原因。如果用户启用

TORCH_DISTRIBUTED_DEBUG=DETAIL并重新运行应用,以下错误信息会揭示根本原因:

work = default_pg.allreduce([tensor], opts)
RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0、This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes:  10
[ torch.LongTensor{1} ]

注意:如需在运行时对调试级别进行细粒度控制,还可以使用以下函数:torch.distributed.set_debug_level()torch.distributed.set_debug_level_from_env()torch.distributed.get_debug_level()

此外,可以将 TORCH_DISTRIBUTED_DEBUG=DETAILTORCH_SHOW_CPP_STACKTRACES=1 结合使用,以便在检测到集合操作不同步时记录完整的调用堆栈。这些集合操作不同步检查适用于所有使用 c10d 集合调用的应用程序,这些调用由通过 torch.distributed.init_process_group()torch.distributed.new_group() API 创建的进程组支持。


日志记录

除了通过 torch.distributed.monitored_barrier()TORCH_DISTRIBUTED_DEBUG 提供的显式调试支持外,torch.distributed 的底层 C++ 库还会输出不同级别的日志消息。这些消息有助于理解分布式训练作业的执行状态,并排查诸如网络连接故障等问题。下表展示了如何通过组合 TORCH_CPP_LOG_LEVELTORCH_DISTRIBUTED_DEBUG 环境变量来调整日志级别。

TORCH_CPP_LOG_LEVELTORCH_DISTRIBUTED_DEBUG实际日志级别
ERROR忽略错误
WARNING忽略警告
INFO忽略信息
INFOINFO调试
INFODETAIL跟踪(即全部)

分布式组件会抛出从 RuntimeError 派生的自定义异常类型:

  • torch.distributed.DistError:这是所有分布式异常的基类型。
  • torch.distributed.DistBackendError:当发生后端特定错误时抛出此异常。例如,如果使用 NCCL 后端且用户尝试使用 NCCL 库不可用的 GPU。
  • torch.distributed.DistNetworkError:当网络库遇到错误时抛出此异常(例如:连接被对端重置)。
  • torch.distributed.DistStoreError:当 Store 遇到错误时抛出此异常(例如:TCPStore 超时)。

class torch.distributed.DistError

分布式库中发生错误时引发的异常


class torch.distributed.DistBackendError

当分布式系统中发生后端错误时引发的异常


class torch.distributed.DistNetworkError

分布式系统中发生网络错误时引发的异常


class torch.distributed.DistStoreError

分布式存储发生错误时引发的异常

如果正在运行单节点训练,可以方便地以交互方式在脚本中设置断点。我们提供了一种便捷的方法来为单个 rank 设置断点:

torch.distributed.breakpoint(rank=0, skip=0)

功能说明

设置断点,但仅对单个指定rank生效。其他所有rank会等待该断点执行完成后才继续运行。

参数说明

  • rank (int) – 指定触发断点的rank编号,默认为0
  • skip (int) – 跳过前skip次对该断点的调用,默认为0

torch.distributed.tensor


注意:torch.distributed.tensor 目前处于 alpha 开发阶段,文档中列出的大部分 API 我们将确保向后兼容性,但必要时可能会进行 API 变更。


PyTorch DTensor(分布式张量)

PyTorch DTensor 提供简单灵活的张量分片原语,能够透明处理分布式逻辑,包括跨设备/主机的分片存储、算子计算和集合通信。DTensor 可用于构建不同的并行解决方案,并支持在多维分片场景下表示分片状态的 state_dict。

以下是基于 DTensor 构建的 PyTorch 原生并行方案示例:

  • 张量并行
  • FSDP2

DTensor 遵循 SPMD(单程序多数据)编程模型,让用户能够像编写具有相同收敛特性的单设备程序那样编写分布式程序。它通过指定 DeviceMeshPlacement 提供统一的张量分片布局(DTensor 布局):

  • DeviceMesh 使用 n 维数组表示集群的设备拓扑和通信器
  • Placement 描述逻辑张量在 DeviceMesh 上的分片布局
    DTensor 支持三种分片类型:Shard(分片)、Replicate(复制)和 Partial(部分)。

DTensor 类 API

DTensortorch.Tensor 的子类。这意味着一旦创建了 DTensor,就可以以与 torch.Tensor 非常相似的方式使用它,包括运行不同类型的 PyTorch 操作符,就像在单个设备上运行它们一样,同时为 PyTorch 操作符提供正确的分布式计算支持。

除了现有的 torch.Tensor 方法外,它还提供了一组额外的方法来与 torch.Tensor 交互、将 DTensor 布局重新分配到新的 DTensor、获取所有设备上的完整张量内容等。


class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad) 

DTensor(分布式张量)是 torch.Tensor 的子类,它为多设备 torch.Tensor 提供了类似单设备的编程抽象。它通过 DeviceMesh 和以下类型的 Placement 来描述分布式张量的分片布局(DTensor Layout):

  • Shard:张量在 DeviceMesh 维度的设备上沿张量维度 dim 分片
  • Replicate:张量在 DeviceMesh 维度的设备上完整复制
  • Partial:张量在 DeviceMesh 维度的设备上待规约

当调用 PyTorch 算子时,DTensor 会重载这些算子以执行分片计算,并在必要时发起通信。在算子计算过程中,DTensor 会根据算子本身的语义正确转换或传播布局(DTensor Layout),并生成新的 DTensor 输出。

为确保调用 PyTorch 算子时 DTensor 分片计算的数值正确性,DTensor 要求算子的每个 Tensor 参数都必须是 DTensor。


注意:直接使用 Tensor 子类构造函数创建 DTensor 并非推荐方式(例如它无法正确处理自动求导,因此不属于公开 API)。请参阅 create_dtensor 章节了解如何正确创建 DTensor

返回类型:DTensor


__create_chunk_list__()

返回一个 ChunkStorageMetadata 列表,该数据类用于描述当前 rank 上本地分片/副本的大小和偏移量。对于 DTensor,每个 rank 只会有一个本地分片/副本,因此返回的列表通常仅包含一个元素。

此双下划线方法主要用于分布式检查点用途。

返回值:一个 List[ChunkStorageMetadata] 对象,表示当前 rank 上的分片大小/偏移量。


static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)

根据指定的 device_meshplacements,从各 rank 上的本地 torch.Tensor 创建一个 DTensor

参数

  • local_tensor (torch.Tensor) – 各 rank 上的本地 torch.Tensor。
  • device_mesh (DeviceMesh, 可选) – 用于放置张量的 DeviceMesh。若未指定,则必须在 DeviceMesh 上下文管理器中调用,默认值:None
  • placements (List[Placement], 可选) – 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的布局列表,其元素数量必须与 device_mesh.ndim 相同。

关键字参数

  • run_check ([bool], 可选) – 以额外通信为代价,跨 rank 执行完整性检查,验证各本地张量的元信息以确保正确性。若 placements 中包含 Replicate,设备网格维度的第一个 rank 上的数据将被广播到其他 rank。默认值:False
  • shape ( torch.Size , 可选) – 指定构建在 local_tensor 之上的 DTensor 大小的整型列表。注意:当各 rank 上 local_tensor 的形状不同时必须提供此参数。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算 shape。默认值:None
  • stride ( tuple , 可选) – 指定 DTensor 步长的整型列表。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算 stride。默认值:None

返回

一个 DTensor 对象

返回类型:DTensor

注意:当 run_check=False 时,用户需自行确保传入的本地张量在各 rank 间正确(即对于 Shard(dim) 布局张量需分片,对于 Replicate() 布局需复制)。否则,所创建 DTensor 的行为将是未定义的。

注意:from_local 是可微操作,创建的 DTensor 对象的 requires_grad 属性将取决于 local_tensor 是否 requires_grad。


full_tensor(*, grad_placements=None) 

返回该DTensor的完整张量。该方法会执行必要的集合通信操作,从所在DeviceMesh的其他rank上收集本地张量并进行拼接。这是以下代码的语法糖:

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

关键字参数

  • grad_placements (List[Placement], 可选) – 该参数描述了从本函数返回的完整张量对应的梯度布局的未来分布方式。

full_tensor将DTensor转换为完整的torch.Tensor,但返回的torch.tensor在后续代码中可能不会保持原始复制的DTensor布局。这个参数是用户提供给autograd的提示,用于处理返回张量的梯度布局与原始复制的DTensor布局不匹配的情况。如果未指定,我们将假定完整张量的梯度布局为复制式分布。

返回值:一个表示该DTensor完整张量的torch.Tensor对象。

返回类型: Tensor

注意:full_tensor是可微分的。


redistribute(device_mesh=None, placements=None, *, async_op=False)

redistribute 执行必要的集体操作,将当前 DTensor 从其现有布局重新分配到新布局,或从当前 DeviceMesh 迁移到新 DeviceMesh。例如,我们可以通过为 DeviceMesh 的每个维度指定 Replicate 布局,将分片(Sharded)DTensor 转换为复制(Replicated)DTensor。

当在 DeviceMesh 的某个维度上从当前布局重新分配到新布局时,将执行以下包含通信集体操作或本地操作:

1、Shard(dim)Replicate()all_gather

2、Shard(src_dim)Shard(dst_dim)all_to_all

3、Replicate()Shard(dim):本地分块(即 torch.chunk

4、Partial()Replicate()all_reduce

5、Partial()Shard(dim)reduce_scatter

redistribute 能够正确推断出针对在 1-D 或 N-D DeviceMesh 上创建的 DTensor 所需的重新分配步骤。

参数

  • device_mesh (DeviceMesh, 可选) – 用于放置 DTensor 的 DeviceMesh。若未指定,则使用当前 DTensor 的 DeviceMesh。

默认值:None

  • placements (List[Placement], 可选) – 描述如何将 DTensor 放置到 DeviceMesh 中的新布局,其元素数量必须与 device_mesh.ndim 相同。

默认值:在所有网格维度上复制(replicate)

关键字参数

  • async_op ([bool], 可选) – 是否以异步方式执行 DTensor 重新分配操作。默认值:False

返回

一个 DTensor 对象

返回类型

DTensor

注意redistribute 是可微分的,这意味着用户无需担心重新分配操作的反向传播公式。

注意redistribute 当前仅支持在同一 DeviceMesh 上重新分配 DTensor。若需将 DTensor 重新分配到不同 DeviceMesh,请提交问题。


to_local(*, grad_placements=None) 

获取当前 rank 上该 DTensor 的本地张量。对于分片情况,返回逻辑张量视图的本地分片;对于复制情况,返回当前 rank 上的副本。

关键字参数

  • grad_placements (List[Placement], 可选) – 该参数描述从本函数返回张量的梯度未来布局。

to_local 将 DTensor 转换为本地张量,且返回的本地张量后续可能不会沿用原 DTensor 的布局。此参数是用户提供给自动求导的提示,用于处理返回张量的梯度布局与原 DTensor 不匹配的情况。若未指定,则默认梯度布局与原 DTensor 相同并用于梯度计算。

返回值:一个 torch.TensorAsyncCollectiveTensor 对象,表示当前 rank 上的本地张量。当返回 AsyncCollectiveTensor 对象时,意味着本地张量尚未就绪(即通信未完成)。此时用户需调用 wait 方法等待本地张量准备就绪。

返回类型 : Tensor

注意:to_local 是可微分的,返回本地张量的 requires_grad 属性将取决于原 DTensor 是否要求梯度。


property device_mesh: [DeviceMesh](distributed.html#torch.distributed.device_mesh.DeviceMesh "torch.distributed.device_mesh.DeviceMesh")

与该 DTensor 对象关联的 DeviceMesh 属性。

注意:device_mesh 是一个只读属性,不可被设置。


property placements:  tuple [[torch.distributed.tensor.placement_types.Placement](https://pytorch.org/docs/stable/data.html#torch.distributed.tensor.placement_types.Placement "torch.distributed.tensor.placement_types.Placement"),...]

该 DTensor 的 placements 属性描述了其在设备网格(DeviceMesh)上的分布布局。

注意placements 是只读属性,不可被修改。


作为分布式通信器的DeviceMesh

DeviceMesh基于DTensor构建,用于抽象描述集群设备拓扑结构,并作为多维通信器(基于ProcessGroup)的载体。如需了解如何创建/使用DeviceMesh的具体细节,请参阅DeviceMesh使用指南。


DTensor 布局类型

DTensor 支持在每个 DeviceMesh 维度上使用以下 Placement 类型:

class torch.distributed.tensor.placement_types.Shard(dim)

Shard(dim)布局描述了张量在维度dim上跨对应DeviceMesh维度的分片方式,其中DeviceMesh维度上的每个rank仅持有全局张量的一个分片。Shard(dim)布局遵循torch.chunk(dim)语义——当张量维度无法在DeviceMesh维度上均匀划分时,DeviceMesh维度上的最后几个分片可能为空。所有DTensor API(如distribute_tensorfrom_local等)均可使用Shard布局。

参数

  • dim (int) - 指定张量在对应DeviceMesh维度上进行分片的维度编号。

警告:当前对无法在DeviceMesh维度上均匀划分的张量维度进行分片属于实验性功能,后续可能变更。

dim: int


class torch.distributed.tensor.placement_types.Replicate

Replicate()布局描述了DTensor在对应的DeviceMesh维度上进行复制的行为,其中DeviceMesh维度上的每个rank都持有全局Tensor的一个副本。所有DTensor API(例如distribute_tensorDTensor.from_local等)都可以使用Replicate布局。


class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')

Partial(reduce_op)布局描述了在指定DeviceMesh维度上待归约的DTensor,其中DeviceMesh维度的每个rank持有全局Tensor的部分值。用户可以通过redistributePartial DTensor转换为指定DeviceMesh维度上的ReplicateShard(dim)布局,这将触发底层的必要通信操作(如allreducereduce_scatter)。

参数

  • reduce_op (str, 可选) – 用于将Partial DTensor转换为Replicated/Sharded DTensor的归约操作。仅支持逐元素的归约操作,包括:“sum”、“avg”、“product”、“max”、“min”,默认值为"sum"。

注意:Partial布局可能作为DTensor运算符的结果生成,且只能通过DTensor.from_local API使用。


reduce_op: str = 'sum'

class torch.distributed.tensor.placement_types.Placement

Placement 类型的基类,用于描述如何将 DTensor 放置在 DeviceMesh 上。PlacementDeviceMesh 共同定义了 DTensor 的布局。

它是三种主要 DTensor 放置类型(ShardReplicatePartial)的基类。

这个类不直接使用,主要作为类型标注存根。


is_partial() 

返回类型:bool


is_replicate()

返回类型:bool


is_shard(dim=None) 

返回类型:bool


创建 DTensor 的不同方式

有三种方法可以构建 DTensor

  • distribute_tensor() 从每个 rank 上的逻辑或"全局" torch.Tensor 创建 DTensor。这可用于对叶子节点 torch.Tensor(即模型参数/缓冲区和输入)进行分片。

  • DTensor.from_local() 从每个 rank 上的本地 torch.Tensor 创建 DTensor,可用于从非叶子节点 torch.Tensor(即前向/反向传播过程中的中间激活张量)创建 DTensor

  • DTensor 提供了专门的张量工厂函数(如 empty()ones()randn() 等),通过直接指定 DeviceMeshPlacement 来创建不同的 DTensor。与 distribute_tensor() 相比,这种方法可以直接在设备上实现分片内存,而不是在初始化逻辑张量内存后再执行分片操作。


从逻辑上的 torch.Tensor 创建 DTensor

torch.distributed 中的 SPMD(单程序多数据)编程模型会启动多个进程(例如通过 torchrun)来执行同一程序。这意味着程序内部的模型会先在不同进程上初始化(例如模型可能在 CPU、元设备上初始化,或者如果有足够内存则直接在 GPU 上初始化)。

DTensor 提供了一个 distribute_tensor() API,可以将模型权重或张量分片为多个 DTensor。该 API 会在每个进程上从“逻辑”张量创建 DTensor,从而使生成的 DTensor 遵循单一设备语义,这对于数值正确性至关重要。


torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)

根据指定的placements将叶子节点torch.Tensor(如nn.Parameter/缓冲区)分发到device_meshdevice_meshplacements的维度必须相同。待分发的tensor是逻辑或"全局"张量,该API会使用DeviceMesh第一个维度的首秩张量作为数据源以保持单设备语义。若需在自动梯度计算过程中构建DTensor,请改用DTensor.from_local()

参数说明

  • tensor (torch.Tensor) – 待分发的张量。注意:若需在设备网格维度上对无法整除的张量进行分片,将使用torch.chunk语义进行分片和散射。非均匀分片行为尚处实验阶段,后续可能变更。
  • device_mesh (DeviceMesh, 可选) – 目标设备网格。若未指定,必须在DeviceMesh上下文管理器中调用,默认值:None
  • placements (List[Placement], 可选) – 描述张量在设备网格上分布方式的定位策略,元素数量必须与device_mesh.ndim相同。若未指定,默认会沿设备网格各维度的首秩复制张量。

关键字参数

  • src_data_rank ( int , 可选) – 逻辑/全局张量的源数据秩,distribute_tensor()通过此参数将分片/副本散射/广播到其他秩。默认使用各DeviceMesh维度上group_rank=0作为数据源以保持单设备语义。若显式传入None,该API将直接使用本地数据而非通过散射/广播保持单设备语义。默认值:0

返回值
返回DTensorXLAShardedTensor对象。

返回类型
DTensor

注意:当使用xla设备类型初始化DeviceMesh时,distribute_tensor会返回XLAShardedTensor。详见此问题。XLA集成功能尚处实验阶段,后续可能变更。

distribute_tensor()外,DTensor还提供distribute_module()API,可在nn.Module层级实现更便捷的分片操作。


torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)

该函数提供了三个功能来控制模块的参数/输入/输出:

1、通过在运行时执行前指定 partition_fn 对模块进行分片处理(即允许用户根据指定的 partition_fn 将模块参数转换为 DTensor 参数)。

2、通过在运行时执行时指定 input_fnoutput_fn 来控制模块的输入或输出(即将输入转换为 DTensor,将输出转换回 torch.Tensor)。

参数

  • module (nn.Module) – 需要分片的用户模块。
  • device_mesh (DeviceMesh) – 用于放置模块的设备网格。
  • partition_fn (Callable) – 用于分片参数的函数(即在 device_mesh 上切分特定参数)。如果未指定 partition_fn,默认会在网格上复制 module 的所有模块参数。
  • input_fn (Callable) – 指定输入分布,即可以控制模块输入的切分方式。input_fn 会作为模块的 forward_pre_hook(前向钩子)安装。
  • output_fn (Callable) – 指定输出分布,即可以控制输出的切分方式,或将其转换回 torch.Tensor。output_fn 会作为模块的 forward_hook(后向钩子)安装。

返回

一个包含所有参数/缓冲区的模块,这些参数/缓冲区均为 DTensor 类型。

返回类型:Module

注意:当使用 xla 设备类型初始化 DeviceMesh 时,distribute_module 会返回带有 PyTorch/XLA SPMD 注释参数的 nn.Module。详情请参阅此问题。XLA 集成目前处于实验阶段,可能会发生变化。


DTensor 工厂函数

DTensor 还提供了专门的张量工厂函数,允许直接创建 DTensor。这些函数使用类似 torch.Tensor 的工厂函数 API(例如 torch.ones、torch.empty 等),并通过额外指定 DeviceMeshPlacement 来配置所创建的 DTensor


torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None) 

返回一个用标量值0填充的DTensor

参数

  • size ( int *...) - 定义输出DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:zeros(1,2,3…) 或 zeros([1,2,3…]) 或 zeros((1,2,3…))

关键字参数

  • requires_grad ([bool], 可选) - 如果为True,自动微分将记录对返回DTensor的操作。默认值:False
  • dtype ( torch.dtype , 可选) - 返回DTensor的期望数据类型。默认值:如果为None,则使用全局默认值(参见torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) - 返回DTensor的期望布局。默认值:torch.strided
  • device_mesh - DeviceMesh类型,包含rank的网格信息
  • placements - Placement类型的序列:ShardReplicate

返回

每个rank上的一个DTensor对象

返回类型

DTensor


torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

返回一个填充了标量值1的DTensor,其形状由可变参数size定义。

参数

  • size ( int *...) – 定义输出DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))

关键字参数

  • dtype ( torch.dtype , 可选) – 返回DTensor的期望数据类型。默认值:如果为None,则使用全局默认值(参见torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) – 返回DTensor的期望布局。默认值:torch.strided
  • requires_grad ([bool], 可选) – 是否应在返回的DTensor上记录自动梯度操作。默认值:False
  • device_meshDeviceMesh类型,包含进程的网格信息
  • placementsPlacement类型的序列:ShardReplicate

返回值:每个进程上的一个DTensor对象

返回类型:DTensor


torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None) 

返回一个填充了未初始化数据的 DTensor。该 DTensor 的形状由可变参数 size 定义。

参数

  • size ( int *...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:empty(1,2,3…)、empty([1,2,3…]) 或 empty((1,2,3…))。

关键字参数

  • dtype ( torch.dtype , 可选) – 返回 DTensor 的期望数据类型。默认值:如果为 None,则使用全局默认值(参见 torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) – 返回 DTensor 的期望布局。默认值:torch.strided
  • requires_grad ([bool], 可选) – 是否在返回的 DTensor 上记录自动求导操作。默认值:False
  • device_meshDeviceMesh 类型,包含进程的网格信息。
  • placementsPlacement 类型的序列:ShardReplicate

返回值:每个进程上的一个 DTensor 对象。

返回类型:DTensor


torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

根据 device_meshplacements 参数,返回一个填充了 fill_valueDTensor,其形状由参数 size 定义。

参数

  • size ( int *...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数,也可以是列表或元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))。
  • fill_value (Scalar) – 用于填充输出张量的值。

关键字参数

  • dtype ( torch.dtype , 可选) – 返回的 DTensor 所需的数据类型。默认值:如果为 None,则使用全局默认值(参见 torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) – 返回的 DTensor 所需的布局。默认值:torch.strided
  • requires_grad ([bool], 可选) – 是否应自动梯度记录对返回的 DTensor 的操作。默认值:False
  • device_meshDeviceMesh 类型,包含 rank 的网格信息。
  • placementsPlacement 类型的序列:ShardReplicate

返回

每个 rank 上的一个 DTensor 对象。

返回类型

DTensor


torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None) 

返回一个填充了区间 [0, 1) 上均匀分布随机数的 DTensor。张量的形状由可变参数 size 定义。

参数

  • size (int *...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:ones(1,2,3…)、ones([1,2,3…]) 或 ones((1,2,3…))。

关键字参数

  • dtype (torch.dtype, 可选) – 返回的 DTensor 所需的数据类型。默认值:如果为 None,则使用全局默认值(参见 torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) – 返回的 DTensor 所需的布局。默认值:torch.strided
  • requires_grad ([bool], 可选) – 如果为 True,则自动微分会记录对返回的 DTensor 的操作。默认值:False
  • device_meshDeviceMesh 类型,包含进程的网格信息。
  • placementsPlacement 类型的序列:ShardReplicate

返回

每个进程上的一个 DTensor 对象。

返回类型

DTensor


torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)

返回一个填充了均值为0、方差为1的正态分布随机数的DTensor,张量的形状由变量参数size定义。

参数

  • size (int *...) - 定义输出DTensor形状的整数序列。可以是可变数量的参数或列表/元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))

关键字参数

  • dtype (torch.dtype, 可选) - 返回DTensor的期望数据类型。默认值:如果为None,则使用全局默认值(参见torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) - 返回DTensor的期望布局。默认值:torch.strided
  • requires_grad ([bool], 可选) - 是否应在返回的DTensor上记录自动求导操作。默认值:False
  • device_mesh - DeviceMesh类型,包含rank的网格信息。
  • placements - Placement类型的序列:ShardReplicate

返回

每个rank上的一个DTensor对象

返回类型

DTensor


调试


日志记录

启动程序时,可以通过设置 torch._logging 中的 TORCH_LOGS 环境变量来启用额外的日志记录功能:

  • TORCH_LOGS=+dtensor 将显示 logging.DEBUG 及以上级别的日志消息
  • TORCH_LOGS=dtensor 将显示 logging.INFO 及以上级别的日志消息
  • TORCH_LOGS=-dtensor 将显示 logging.WARNING 及以上级别的日志消息

调试工具

为了调试应用了DTensor的程序,并深入了解底层发生的集合通信细节,DTensor提供了CommDebugMode调试模式:

class torch.distributed.tensor.debug.CommDebugMode 

CommDebugMode 是一个上下文管理器,用于统计其上下文中功能集合操作的次数。它通过 TorchDispatchMode 实现这一功能。

注意:目前并非所有集合操作都受支持。

使用示例


mod = ...
comm_mode = CommDebugMode()
with comm_mode:mod.sum().backward()
print(comm_mode.get_comm_counts())

generate_comm_debug_tracing_table(noise_level=3)

生成详细表格,展示模块层级的操作和集体追踪信息。信息量取决于 noise_level 参数:

0、打印模块层级的集体调用次数统计

1、打印未包含在简单操作中的 dTensor 操作及模块信息

2、打印未包含在简单操作中的所有操作

3、打印全部操作


generate_json_dump(file_name='comm_mode_log.json', noise_level=3) 

生成用于构建浏览器可视化的json文件

0、打印模块级别的聚合计数

1、打印未包含在简单操作中的dTensor运算

2、打印未包含在简单操作中的运算

3、打印所有运算


get_comm_counts()

返回通信计数作为字典。

返回值:以字典形式返回通信计数。

返回类型:Dict[Any, int]


get_parameter_info() 

返回类型:dict[str , dict[str , Any ]


get_sharding_info()

返回类型 : dict[str, dict[str, Any]]


get_total_counts()

返回类型:int


log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)

替代控制台 CommDebugMode 输出的方案,可将日志写入用户指定的文件

为了可视化维度少于 3 的 DTensor 分片情况,DTensor 提供了 visualize_sharding() 方法:

torch.distributed.tensor.debug.visualize_sharding(dtensor, header='') 

在终端中可视化一维或二维 DTensor 的分片情况。


注意:需安装 tabulate 包。空张量不会显示分片信息。


实验性功能

DTensor 还提供了一系列实验性功能。这些功能要么处于原型开发阶段,要么基础功能已完成但正在收集用户反馈。如果您对这些功能有任何意见,请向 PyTorch 提交 issue。


torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)

context_parallel 是一个实验性 API,用于实现上下文并行(CP)。该 API 执行两个操作:1) 将 SDPA(torch.nn.functional.scaled_dot_product_attention)替换为支持 CP 的版本;2) 沿序列维度对 buffers 进行分片,每个 rank 根据 mesh 保留对应的分片。

参数

  • mesh (DeviceMesh) – 用于上下文并行的设备网格。
  • buffers (Optional[List[torch.Tensor]]) – 依赖序列维度的缓冲区。例如输入批次、标签和位置嵌入缓冲区。这些缓冲区必须沿序列维度分片以确保准确性。分片操作会就地执行,缓冲区的形状在上下文中会发生变化。上下文结束后,缓冲区会恢复原状。可以通过 no_restore_buffers 指定哪些缓冲区无需恢复。注意 buffers 不应包含任何 nn.Parameter。
  • buffer_seq_dims (Optional[List[int]])buffers 的序列维度。
  • no_restore_buffers (Optional[Set[torch.Tensor]]) – 此集合中的缓冲区在上下文退出后不会被恢复。该集合必须是 buffers 的子集。如果缓冲区在上下文退出后不再使用,可以将其加入此列表以避免额外的恢复时间。

返回类型

Generator

警告:torch.distributed._tensor.experimental.attention.context_parallel 是 PyTorch 中的原型功能。API 可能会发生变化。


torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)

local_map() 是一个实验性 API,允许用户将 DTensor 传递给原本设计用于处理 torch.Tensor 的函数。其实现原理是提取 DTensor 的本地分量,调用目标函数,然后根据 out_placements 将输出重新封装为 DTensor

参数说明

  • func (Callable) – 需要应用于每个 DTensor 本地分片的函数
  • out_placements (Union [PlacementType, Tuple[PlacementType, …]]) – 函数展平输出中 DTensor 的目标分布位置:
    • 当展平输出为单个值时,out_placements 应为 PlacementType 类型
    • 当展平输出包含多个值时,out_placements 应为与输出值一一对应的 PlacementType 元组
    • 对于 Tensor 输出,使用 PlacementType 作为其分布位置(即 Tuple[Placement] 值)
    • 对于非 Tensor 输出,PlacementType 应为 None

注意:当没有传入 DTensor 参数时,即使 out_placements 不为 None,结果函数也应忽略目标分布位置,因为此时函数并非运行在 DTensor 上。

  • in_placements (Tuple[PlacementType, …], optional) – 函数展平输入中 DTensor 的必需分布位置:
    • 指定时,local_map() 会检查每个 DTensor 参数的分布位置是否符合要求
    • 当分布位置不符且 redistribute_inputs=False 时会抛出异常
    • redistribute_inputs=True 时,参数会先重分布到要求的分片位置再传递给函数
    • 例外情况:当必需分布位置非 None 但参数是 torch.Tensor 时,跳过分布检查直接传递参数
    • 默认值:None
  • device_mesh (DeviceMesh, optional) – 所有 DTensor 所处的设备网格。未指定时从输入 DTensor 的设备网格推断。要求所有 DTensor 必须位于同一设备网格。默认值:None
  • redistribute_inputs ([bool], optional) – 布尔值,指示当输入 DTensor 分布位置与要求不符时是否进行重分布:
    • 为 False 且需要重分布时会抛出异常
    • 默认值:False

返回值
返回一个可调用对象,该对象会将 func 应用于输入 DTensor 的每个本地分片,并将函数返回值构造成 DTensor

异常情况

  • AssertionError – 当出现以下情况时触发:
    • 输入 DTensor 不在同一设备网格
    • 输入 DTensordevice_mesh 参数指定的设备网格不同
    • 非 Tensor 输出对应的 out_placements 不为 None
  • ValueError – 当 redistribute_inputs=False 但输入 DTensor 需要根据 in_placements 重分布时触发

示例:

>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>> 
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # 在一维网格上的行分片布局
>>> col_wise = [Shard(1)]  # 在一维网格上的列分片布局
>>> 
>>> # local_mm_allreduce_forward是封装了DTensor/Tensor转换的函数
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=[col_wise, row_wise],
>>>     device_mesh=device_mesh,
>>> )
>>> 
>>> W_dt = distribute_tensor(
...     W, device_mesh, (col_wise)
... )  # 列分片的W张量
>>> X_dt = distribute_tensor(
...     X, device_mesh, (row_wise)
... )  # 行分片的X张量
>>> Y_dt = local_mm_allreduce_forward(
...     device_mesh, W_dt, X_dt
... )  # 对DTensors应用local_mm_allreduce_forward

注意:此 API 目前处于实验阶段,可能会发生变化


torch.distributed.tensor.experimental.register_sharding(op)

register_sharding() 是一个实验性 API,允许用户在张量输入输出为 DTensor 时,为运算符注册分片策略。

该 API 在以下场景中特别有用:(1) 当 op 不存在默认分片策略时(例如 opDTensor 不支持的定制运算符);(2) 当用户希望覆盖现有运算符的默认分片策略时。

参数说明

  • op (Union[OpOverload*, List[OpOverload]]) —— 需要注册自定义分片函数的单个运算符或运算符列表。

返回值

返回一个函数装饰器,可用于包装定义 op 所指定运算符分片策略的函数。定义的分片策略将被注册到 DTensor 中,若 DTensor 已实现该运算符,则会覆盖其默认分片策略。自定义分片函数的输入参数与原运算符相同(若参数为 torch.Tensor 则会被替换为 DTensor 内部使用的类张量对象)。该函数应返回由二元组构成的序列,每个二元组分别指定可接受的输出布局及其对应的输入布局。

使用示例


>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>> 
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>> 
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)], 
>>>                 [Shard(sharding_dim), None, None], 
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>> 
>>>     return acceptable_shardings

注意:此 API 目前处于实验阶段,后续可能会发生变化


torch.distributed.tensor


注意:torch.distributed.tensor 目前处于 alpha 开发阶段,文档中列出的大部分 API 我们将确保向后兼容性,但必要时可能会进行 API 变更。


PyTorch DTensor(分布式张量)

PyTorch DTensor 提供简单灵活的张量分片原语,能够透明处理分布式逻辑,包括跨设备/主机的分片存储、算子计算和集合通信。DTensor 可用于构建不同的并行解决方案,并支持在多维分片场景下表示分片状态的 state_dict。

以下是基于 DTensor 构建的 PyTorch 原生并行方案示例:

  • 张量并行
  • FSDP2

DTensor 遵循 SPMD(单程序多数据)编程模型,让用户能够像编写具有相同收敛特性的单设备程序那样编写分布式程序。它通过指定 DeviceMeshPlacement 提供统一的张量分片布局(DTensor 布局):

  • DeviceMesh 使用 n 维数组表示集群的设备拓扑和通信器
  • Placement 描述逻辑张量在 DeviceMesh 上的分片布局
    DTensor 支持三种分片类型:Shard(分片)、Replicate(复制)和 Partial(部分)。

DTensor 类 API

DTensortorch.Tensor 的子类。这意味着一旦创建了 DTensor,就可以以与 torch.Tensor 非常相似的方式使用它,包括运行不同类型的 PyTorch 操作符,就像在单个设备上运行它们一样,同时为 PyTorch 操作符提供正确的分布式计算支持。

除了现有的 torch.Tensor 方法外,它还提供了一组额外的方法来与 torch.Tensor 交互、将 DTensor 布局重新分配到新的 DTensor、获取所有设备上的完整张量内容等。


class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad) 

DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides single-device like
abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding
layout (DTensor Layout) through the DeviceMesh and following types of Placement:

  • Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension
  • Replicate: Tensor replicated on the devices of the DeviceMesh dimension
  • Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension

When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue
communications whenever necessary. Along with the operator computation, DTensor will transform or propagate the placements (DTensor Layout) properly (based on the operator semantic itself) and generate new DTensor outputs.

To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor
requires every Tensor argument of the operator be DTensor.


Note: Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor
(i.e. it does not handle autograd correctly hence is not the public API). Please refer to the create_dtensor
section to see how to create a DTensor.

Return type
DTensor


__create_chunk_list__()

返回一个 ChunkStorageMetadata 列表,该数据类用于描述当前 rank 上本地分片/副本的大小和偏移量。对于 DTensor,每个 rank 只会有一个本地分片/副本,因此返回的列表通常仅包含一个元素。

此双下划线方法主要用于分布式检查点用途。

返回值:一个 List[ChunkStorageMetadata] 对象,表示当前 rank 上的分片大小/偏移量。


static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)

根据指定的 device_meshplacements,从各 rank 上的本地 torch.Tensor 创建一个 DTensor

参数

  • local_tensor (torch.Tensor) – 各 rank 上的本地 torch.Tensor。
  • device_mesh (DeviceMesh, 可选) – 用于放置张量的 DeviceMesh。若未指定,则必须在 DeviceMesh 上下文管理器中调用,默认值:None
  • placements (List[Placement], 可选) – 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的布局列表,其元素数量必须与 device_mesh.ndim 相同。

关键字参数

  • run_check ([bool], 可选) – 以额外通信为代价,跨 rank 执行完整性检查,验证各本地张量的元信息以确保正确性。若 placements 中包含 Replicate,设备网格维度的第一个 rank 上的数据将被广播到其他 rank。默认值:False
  • shape ( torch.Size , 可选) – 指定构建在 local_tensor 之上的 DTensor 大小的整型列表。注意:当各 rank 上 local_tensor 的形状不同时必须提供此参数。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算 shape。默认值:None
  • stride ( tuple , 可选) – 指定 DTensor 步长的整型列表。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算 stride。默认值:None

返回

一个 DTensor 对象

返回类型:DTensor

注意:当 run_check=False 时,用户需自行确保传入的本地张量在各 rank 间正确(即对于 Shard(dim) 布局张量需分片,对于 Replicate() 布局需复制)。否则,所创建 DTensor 的行为将是未定义的。

注意:from_local 是可微操作,创建的 DTensor 对象的 requires_grad 属性将取决于 local_tensor 是否 requires_grad。


full_tensor(*, grad_placements=None) 

Return the full tensor of this DTensor. It will perform necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate
them together. It’s a syntatic sugar of the following code:

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

Keyword Arguments

  • grad_placements (List[Placement], optional) – the placements describes the future layout of any gradient layout of the full Tensor returned from this function.
    full_tensor converts DTensor to a full torch.Tensor and the returned torch.tensor
    might not be used as the original replicated DTensor layout later in the code. This
    argument is the hint that user can give to autograd in case the gradient
    layout of the returned tensor does not match the original replicated DTensor layout.
    If not specified, we will assume the gradient layout of the full tensor be replicated.

Returns
A torch.Tensor object that represents the full tensor of this DTensor.

Return type : Tensor


Note: full_tensor is differentiable.


redistribute(device_mesh=None, placements=None, *, async_op=False)

redistribute 执行必要的集体操作,将当前 DTensor 从其现有布局重新分配到新布局,或从当前 DeviceMesh 迁移到新 DeviceMesh。例如,我们可以通过为 DeviceMesh 的每个维度指定 Replicate 布局,将分片(Sharded)DTensor 转换为复制(Replicated)DTensor。

当在 DeviceMesh 的某个维度上从当前布局重新分配到新布局时,将执行以下包含通信集体操作或本地操作:

1、Shard(dim)Replicate()all_gather
2、Shard(src_dim)Shard(dst_dim)all_to_all
3、Replicate()Shard(dim):本地分块(即 torch.chunk
4、Partial()Replicate()all_reduce
5、Partial()Shard(dim)reduce_scatter

redistribute 能够正确推断出针对在 1-D 或 N-D DeviceMesh 上创建的 DTensor 所需的重新分配步骤。

参数

  • device_mesh (DeviceMesh, 可选) – 用于放置 DTensor 的 DeviceMesh。若未指定,则使用当前 DTensor 的 DeviceMesh。
    默认值:None
  • placements (List[Placement], 可选) – 描述如何将 DTensor 放置到 DeviceMesh 中的新布局,其元素数量必须与 device_mesh.ndim 相同。
    默认值:在所有网格维度上复制(replicate)

关键字参数

  • async_op ([bool], 可选) – 是否以异步方式执行 DTensor 重新分配操作。默认值:False

返回
一个 DTensor 对象

返回类型
DTensor

注意redistribute 是可微分的,这意味着用户无需担心重新分配操作的反向传播公式。

注意redistribute 当前仅支持在同一 DeviceMesh 上重新分配 DTensor。若需将 DTensor 重新分配到不同 DeviceMesh,请提交问题。


to_local(*, grad_placements=None) 

Get the local tensor of this DTensor on its current rank. For sharding it returns a local shard of the logical tensor view, for replication it returns the replica on its current rank.

Keyword Arguments

  • grad_placements (List[Placement], optional) – the placements describes the future layout of any gradient layout of the Tensor returned from this function.
    to_local converts DTensor to local tensor and the returned local tensor
    might not be used as the original DTensor layout later in the code. This
    argument is the hint that user can give to autograd in case the gradient
    layout of the returned tensor does not match the original DTensor layout.
    If not specified, we will assume the gradient layout remains the same as the original DTensor and use that for gradient computation.

Returns
A torch.Tensor or AsyncCollectiveTensor object. it represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, it means the local tensor is not ready yet (i.e. communication is not finished). In this case, user needs to call wait to wait the local tensor to be ready.

Return type : Tensor


Note: to_local is differentiable, the requires_grad of the local tensor returned
will depend on if the DTensor requires_grad or not.


property device_mesh: [DeviceMesh](distributed.html#torch.distributed.device_mesh.DeviceMesh "torch.distributed.device_mesh.DeviceMesh")

The DeviceMesh attribute that associates with this DTensor object.

Note: device_mesh is a read-only property, it can not be set.


property placements:  tuple [[torch.distributed.tensor.placement_types.Placement](https://pytorch.org/docs/stable/data.html#torch.distributed.tensor.placement_types.Placement "torch.distributed.tensor.placement_types.Placement"),...]

该 DTensor 的 placements 属性描述了其在设备网格(DeviceMesh)上的分布布局。

注意placements 是只读属性,不可被修改。


作为分布式通信器的DeviceMesh

DeviceMesh基于DTensor构建,用于抽象描述集群设备拓扑结构,并作为多维通信器(基于ProcessGroup)的载体。如需了解如何创建/使用DeviceMesh的具体细节,请参阅DeviceMesh使用指南。


DTensor 布局类型

DTensor 支持在每个 DeviceMesh 维度上使用以下 Placement 类型:

class torch.distributed.tensor.placement_types.Shard(dim)

Shard(dim)布局描述了张量在维度dim上跨对应DeviceMesh维度的分片方式,其中DeviceMesh维度上的每个rank仅持有全局张量的一个分片。Shard(dim)布局遵循torch.chunk(dim)语义——当张量维度无法在DeviceMesh维度上均匀划分时,DeviceMesh维度上的最后几个分片可能为空。所有DTensor API(如distribute_tensorfrom_local等)均可使用Shard布局。

参数

  • dim (int) - 指定张量在对应DeviceMesh维度上进行分片的维度编号。

警告:当前对无法在DeviceMesh维度上均匀划分的张量维度进行分片属于实验性功能,后续可能变更。

dim: int


class torch.distributed.tensor.placement_types.Replicate

The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (i.e. distribute_tensor, DTensor.from_local, etc.)


class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')

The Partial(reduce_op) placement describes the DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds the partial value of the global Tensor. User can redistribute the Partial DTensor to a Replicate or Shard(dim)
placement on the specified DeviceMesh dimension using redistribute, which would trigger necessary communication operations under the hood (i.e. allreduce, reduce_scatter).


Parameters

  • reduce_op (str, optional) – The reduction op to be used for the partial DTensor to produce Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”, default: “sum”.

Note: The Partial placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local API.

reduce_op: str = 'sum'

class torch.distributed.tensor.placement_types.Placement

Placement 类型的基类,用于描述如何将 DTensor 放置在 DeviceMesh 上。PlacementDeviceMesh 共同定义了 DTensor 的布局。

它是三种主要 DTensor 放置类型(ShardReplicatePartial)的基类。

这个类不直接使用,主要作为类型标注存根。


is_partial() 

Return type : bool


is_replicate()

返回类型:bool


is_shard(dim=None) 

Return type : bool


Different ways to create a DTensor

There’re three ways to construct a DTensor😗 distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. This could be used to shard the leaf torch.Tensor s (i.e. model parameters/buffers and inputs).

  • DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank, which can be used to create DTensor from a non-leaf torch.Tensor s (i.e. intermediate activation
    tensors during forward/backward).
  • DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) to allow different DTensor creations by directly specifying the DeviceMesh and Placement. Compare to distribute_tensor(), this could directly materializing the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory.

Create DTensor from a logical torch.Tensor

The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes
(i.e. via torchrun) to execute the same program, this means that the model inside the program would be initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly on GPU if enough memory).

DTensor offers a distribute_tensor() API that could shard the model weights or Tensors to DTensor s, where it would create a DTensor from the “logical” Tensor on each process. This would empower the created
DTensor s to comply with the single device semantic, which is critical for numerical correctness.


torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)

Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or “global” tensor, and the API would use the tensor from first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd
computation, please use DTensor.from_local() instead.


Parameters

  • tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use torch.chunk
    semantic to shard the tensor and scatter the shards. The uneven sharding
    behavior is experimental and subject to change.
  • device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor, if not specified, must be called under a DeviceMesh context
    manager, default: None
  • placements (List[Placement], optional) – the placements that describes how to place the tensor on DeviceMesh, must have the same number of elements as device_mesh.ndim. If not specified, we will by default replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh.

Keyword Arguments

  • src_data_rank ( int , optional) – the rank of the source data for the logical/global tensor, it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. by default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None explicitly, distribute_tensor() simply uses
    its local data instead of trying to preserve the single-device semantic via scatter/broadcast.
    Default: 0

Returns
A DTensor or XLAShardedTensor object.

Return type
DTensor


Note: When initialize the DeviceMesh with the xla device_type, distribute_tensor
return XLAShardedTensor instead. see this issuefor more details. The XLA integration is experimental and subject to change.

Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier
sharding on the nn.Module level


torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)

该函数提供了三个功能来控制模块的参数/输入/输出:

1、通过在运行时执行前指定 partition_fn 对模块进行分片处理(即允许用户根据指定的 partition_fn 将模块参数转换为 DTensor 参数)。

2、通过在运行时执行时指定 input_fnoutput_fn 来控制模块的输入或输出(即将输入转换为 DTensor,将输出转换回 torch.Tensor)。

参数

  • module (nn.Module) – 需要分片的用户模块。
  • device_mesh (DeviceMesh) – 用于放置模块的设备网格。
  • partition_fn (Callable) – 用于分片参数的函数(即在 device_mesh 上切分特定参数)。如果未指定 partition_fn,默认会在网格上复制 module 的所有模块参数。
  • input_fn (Callable) – 指定输入分布,即可以控制模块输入的切分方式。input_fn 会作为模块的 forward_pre_hook(前向钩子)安装。
  • output_fn (Callable) – 指定输出分布,即可以控制输出的切分方式,或将其转换回 torch.Tensor。output_fn 会作为模块的 forward_hook(后向钩子)安装。

返回

一个包含所有参数/缓冲区的模块,这些参数/缓冲区均为 DTensor 类型。

返回类型:Module

注意:当使用 xla 设备类型初始化 DeviceMesh 时,distribute_module 会返回带有 PyTorch/XLA SPMD 注释参数的 nn.Module。详情请参阅此问题。XLA 集成目前处于实验阶段,可能会发生变化。


DTensor 工厂函数

DTensor 还提供了专门的张量工厂函数,允许直接创建 DTensor。这些函数使用类似 torch.Tensor 的工厂函数 API(例如 torch.ones、torch.empty 等),并通过额外指定 DeviceMeshPlacement 来配置所创建的 DTensor


torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None) 

Returns a DTensor filled with the scalar value 0.


Parameters

  • size ( int *...) – a sequence of integers defining the shape of the output DTensor.
    Can be a variable number of arguments or a collection like a list or tuple.
    E.g.: zeros(1,2,3…) or zeros([1,2,3…]) or zeros((1,2,3…))

Keyword Arguments

  • requires_grad ([bool], optional) – If autograd should record operations on the returned DTensor. Default: False.
  • dtype ( torch.dtype , optional) – the desired data type of returned DTensor.
    Default: if None, uses a global default (see torch.set_default_dtype()).
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), optional) – the desired layout of returned DTensor.
    Default: torch.strided.
  • device_meshDeviceMesh type, contains the mesh info of ranks
  • placements – a sequence of Placement type: Shard, Replicate

Returns
A DTensor object on each rank

Return type
DTensor


torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

返回一个填充了标量值1的DTensor,其形状由可变参数size定义。

参数

  • size ( int *...) – 定义输出DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))

关键字参数

  • dtype ( torch.dtype , 可选) – 返回DTensor的期望数据类型。默认值:如果为None,则使用全局默认值(参见torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) – 返回DTensor的期望布局。默认值:torch.strided
  • requires_grad ([bool], 可选) – 是否应在返回的DTensor上记录自动梯度操作。默认值:False
  • device_meshDeviceMesh类型,包含进程的网格信息
  • placementsPlacement类型的序列:ShardReplicate

返回值:每个进程上的一个DTensor对象

返回类型:DTensor


torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None) 

Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.


Parameters

  • size ( int *...) – a sequence of integers defining the shape of the output DTensor.
    Can be a variable number of arguments or a collection like a list or tuple.
    E.g.: empty(1,2,3…) or empty([1,2,3…]) or empty((1,2,3…))

Keyword Arguments

  • dtype ( torch.dtype , optional) – the desired data type of returned DTensor.
    Default: if None, uses a global default (see torch.set_default_dtype()). layout (torch.layout, optional): the desired layout of returned DTensor.
    Default: torch.strided.
  • requires_grad ([bool], optional) – If autograd should record operations on the returned DTensor. Default: False.
  • device_meshDeviceMesh type, contains the mesh info of ranks
  • placements – a sequence of Placement type: Shard, Replicate

Returns
A DTensor object on each rank

Return type
DTensor


torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

根据 device_meshplacements 参数,返回一个填充了 fill_valueDTensor,其形状由参数 size 定义。

参数

  • size ( int *...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数,也可以是列表或元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))。
  • fill_value (Scalar) – 用于填充输出张量的值。

关键字参数

  • dtype ( torch.dtype , 可选) – 返回的 DTensor 所需的数据类型。默认值:如果为 None,则使用全局默认值(参见 torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) – 返回的 DTensor 所需的布局。默认值:torch.strided
  • requires_grad ([bool], 可选) – 是否应自动梯度记录对返回的 DTensor 的操作。默认值:False
  • device_meshDeviceMesh 类型,包含 rank 的网格信息。
  • placementsPlacement 类型的序列:ShardReplicate

返回

每个 rank 上的一个 DTensor 对象。

返回类型

DTensor


torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None) 

Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable
argument size.


Parameters

  • size ( int *...) – a sequence of integers defining the shape of the output DTensor.
    Can be a variable number of arguments or a collection like a list or tuple.
    E.g.: ones(1,2,3…) or ones([1,2,3…]) or ones((1,2,3…))

Keyword Arguments

  • dtype ( torch.dtype , optional) – the desired data type of returned DTensor.
    Default: if None, uses a global default (see torch.set_default_dtype()).
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), optional) – the desired layout of returned DTensor.
    Default: torch.strided.
  • requires_grad ([bool], optional) – If autograd should record operations on the returned DTensor. Default: False.
  • device_meshDeviceMesh type, contains the mesh info of ranks.
  • placements – a sequence of Placement type: Shard, Replicate

Returns
A DTensor object on each rank

Return type
DTensor


torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)

返回一个填充了均值为0、方差为1的正态分布随机数的DTensor,张量的形状由变量参数size定义。

参数

  • size (int *...) - 定义输出DTensor形状的整数序列。可以是可变数量的参数或列表/元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))

关键字参数

  • dtype (torch.dtype, 可选) - 返回DTensor的期望数据类型。默认值:如果为None,则使用全局默认值(参见torch.set_default_dtype())。
  • layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选) - 返回DTensor的期望布局。默认值:torch.strided
  • requires_grad ([bool], 可选) - 是否应在返回的DTensor上记录自动求导操作。默认值:False
  • device_mesh - DeviceMesh类型,包含rank的网格信息。
  • placements - Placement类型的序列:ShardReplicate

返回

每个rank上的一个DTensor对象

返回类型

DTensor


调试


日志记录

启动程序时,可以通过设置 torch._logging 中的 TORCH_LOGS 环境变量来启用额外的日志记录功能:

  • TORCH_LOGS=+dtensor 将显示 logging.DEBUG 及以上级别的日志消息
  • TORCH_LOGS=dtensor 将显示 logging.INFO 及以上级别的日志消息
  • TORCH_LOGS=-dtensor 将显示 logging.WARNING 及以上级别的日志消息

调试工具

为了调试应用了DTensor的程序,并深入了解底层发生的集合通信细节,DTensor提供了CommDebugMode调试模式:

class torch.distributed.tensor.debug.CommDebugMode 

CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.


Note: Not all collectives are supported yet.


Example usage

mod = ...
comm_mode = CommDebugMode()
with comm_mode:mod.sum().backward()
print(comm_mode.get_comm_counts())

generate_comm_debug_tracing_table(noise_level=3)

生成详细表格,展示模块层级的操作和集体追踪信息。信息量取决于 noise_level 参数:

0、打印模块层级的集体调用次数统计

1、打印未包含在简单操作中的 dTensor 操作及模块信息

2、打印未包含在简单操作中的所有操作

3、打印全部操作


generate_json_dump(file_name='comm_mode_log.json', noise_level=3) 

Creates json file used to build browser visual
0、prints module-level collective counts
1、prints dTensor operations not included in trivial operations
2、prints operations not included in trivial operations
3、prints all operations


get_comm_counts()

返回通信计数作为字典。

返回值:以字典形式返回通信计数。

返回类型:Dict[Any, int]


get_parameter_info() 

Return type : dict[str , dict[str , Any ]


get_sharding_info()

返回类型 : dict[str, dict[str, Any]]


get_total_counts()

Return type : int


log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)

替代控制台 CommDebugMode 输出的方案,可将日志写入用户指定的文件

为了可视化维度少于 3 的 DTensor 分片情况,DTensor 提供了 visualize_sharding() 方法:

torch.distributed.tensor.debug.visualize_sharding(dtensor, header='') 

Visualizes sharding in the terminal for DTensor that are 1D or 2D.


Note: This requires the tabulate package. No sharding info will be printed for empty tensors


Experimental Features

DTensor also provides a set of experimental features. These features are either in prototyping stage, or the basic
functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to these features.


torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)

context_parallel is an experimental API to enable context
parallelism (CP). This API performs two actions: 1) patch the SDPA
(torch.nn.functional.scaled_dot_product_attention) with the CP-enabled
one, 2) shard buffers along the sequence dimension and each rank will
preserve the corresponding shard according mesh.


Parameters

  • mesh (DeviceMesh) – the device mesh for the context parallelism.
  • buffers (Optional[List[torch.Tensor]]) – buffers that the usage depend on the sequence dimension. Examples are input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure the accuracy. The sharding will
    happen in-place, the buffer’s shape will change within the context.
    The buffers will be restored after the context finishes.
    no_restore_buffers can be used to specify which buffers don’t
    need to be restored. Note that buffers should not contain any
    nn.Parameter.
  • buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.
  • no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in these set
    won’t be restored after the context exits. This set must be a subset of buffers. If the buffers won’t be used after the context exits, these buffers can be put in this list to avoid extra restore time.

Return type
Generator


Warning: torch.distributed._tensor.experimental.attention.context_parallel is a prototype feature in PyTorch. The API is subject to change.


torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)

local_map() is an experimental API that allows users to pass DTensor s to a function that is written to be applied on torch.Tensor s. It is done by extracting the local components of DTensor, call the function, and wrap the outputs to DTensor according to the out_placements.


Parameters

  • func (Callable) – the function to be applied on each local shard of DTensor s.
  • out_placements (Union [PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensor s in func’s flattened output.
    If the flattened output is a single value, the out_placements should be of type PlacementType. Otherwise if the flattened output has multiple
    values, the out_placements should be a tuple of PlacementType values 1:1
    mapping to the flattened output.
    Besides, for Tensor output, we use PlacementType as its
    placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType
    should be None.
    Note that the only exception is when no DTensor argument is passed
    in. In this case, even if out_placements is not None, the result function
    should ignore the desired placements because the function is not running with DTensor s.
  • in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensor s in the flattened inputs of func.
    If in_placements is specified, local_map() would examine whether the placements of each DTensor argument is the same as the required
    placements or not. If the placements are not the same and redistribute_inputs is False, an exception will be raised. Otherwise if redistribute_inputs is True, the argument will be first redistributed to the required sharding placements before passing its local tensor to func.
    The only exception is when required placements are not None and the argument is a torch.Tensor . In this case, the placements examination
    will be skipped and the argument will be directly passed to func.
    If in_placements is None, no placements examination will be performed.
    Default: None
  • device_mesh (DeviceMesh, optional) – the device mesh that all the DTensor s are placed on. If not
    specified, this will be inferred from the input DTensor s’ device
    mesh. local_map requires every DTensor s to be placed on the same device mesh. Default: None.
  • redistribute_inputs ([bool], optional) – the bool value indicating whether to reshard the input DTensor s when
    their placements are different from the required input placements. If this value is False and some DTensor input has a different placement, an exception will be raised. Default: False.

Returns
A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func.

Raises

  • AssertionError – If the input DTensor is not placed on the same device
    mesh, or if they are placed on a different device mesh than the device_mesh
    argument passed in.
  • AssertionError – For any non-DTensor output, we require its corresponding
    output placement in out_placements be None. An AssertionError will be raised
    if this is not the case.
  • ValueError – If redistribute_inputs=False but the input DTensor needs
    a redistribution according to in_placements.

Example :

>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>> 
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # 在一维网格上的行分片布局
>>> col_wise = [Shard(1)]  # 在一维网格上的列分片布局
>>> 
>>> # local_mm_allreduce_forward是封装了DTensor/Tensor转换的函数
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=[col_wise, row_wise],
>>>     device_mesh=device_mesh,
>>> )
>>> 
>>> W_dt = distribute_tensor(
...     W, device_mesh, (col_wise)
... )  # 列分片的W张量
>>> X_dt = distribute_tensor(
...     X, device_mesh, (row_wise)
... )  # 行分片的X张量
>>> Y_dt = local_mm_allreduce_forward(
...     device_mesh, W_dt, X_dt
... )  # 对DTensors应用local_mm_allreduce_forward

Note: This API is currently experimental and subject to change

torch.distributed.tensor.experimental.register_sharding(op)

register_sharding() is an experimental API that allows users to register sharding
strategies for an operator when the tensor inputs and outputs are DTensor.
It can be useful when: (1) there doesn’t exist a default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor; (2)
when users would like to overwrite default sharding strategies of existing operators.


Parameters

  • op (Union[OpOverload*, List[OpOverload]]) – An op or a list of ops to register the customized sharding function.

Returns
A function decorator which can be used to wrap a function that defines the sharding
strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor , it will be replaced by a tensor-like object that DTensor uses internally). The function should
return a sequence of 2-tuples, each specifying acceptable output placements and its
corresponding intput placements.


Example:

>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>> 
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>> 
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)], 
>>>                 [Shard(sharding_dim), None, None], 
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>> 
>>>     return acceptable_shardings

注意:此 API 目前处于实验阶段,后续可能会发生变化


通用Join上下文管理器

通用Join上下文管理器用于简化不均匀输入的分布式训练。本文档概述了相关类的API:JoinJoinableJoinHook。如需教程,请参阅使用Join上下文管理器进行不均匀输入的分布式训练。


class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)

该类定义了通用的join上下文管理器,允许在进程加入后调用自定义钩子。

这些钩子应屏蔽未加入进程的集体通信,以防止挂起和错误,并确保算法正确性。有关钩子定义的详细信息,请参阅JoinHook

警告:上下文管理器要求每个参与的Joinable在其每次迭代的集体通信之前调用notify_join_context()方法以确保正确性。

警告:上下文管理器要求JoinHook对象中的所有process_group属性必须相同。如果存在多个JoinHook对象,则使用第一个对象的device

进程组和设备信息用于检查未加入的进程,并在启用throw_on_early_termination时通知进程抛出异常,这两者都使用all-reduce操作。

参数

  • joinables (List[Joinable ]) – 参与的Joinable列表;将按给定顺序迭代它们的钩子。
  • enable ([bool]) – 启用不均匀输入检测的标志;设置为False将禁用上下文管理器的功能,仅当用户确认输入不会不均匀时才应设置(默认值:True)。
  • throw_on_early_termination ([bool]) – 控制检测到不均匀输入时是否抛出异常的标志(默认值:False)。

示例:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>> >
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring

static notify_join_context(joinable)

通知连接上下文管理器,调用进程尚未加入。

如果设置了 throw_on_early_termination=True,则会检查是否检测到输入不均衡
(即是否有进程已提前加入),若存在则抛出异常。

此方法应在 Joinable 对象执行每次迭代的集合通信前调用。
例如,在 DistributedDataParallel 的前向传播开始时应当调用此方法。

只有传入上下文管理器的第一个 Joinable 对象会执行此方法中的集合通信,
其余对象调用此方法时不执行实际操作。

参数

  • joinable (Joinable) – 调用此方法的 Joinable 对象。

返回值
若当前 joinable 是传入上下文管理器的首个对象,则返回一个异步工作句柄,
用于通过全减操作通知上下文管理器该进程尚未加入;否则返回 None


class torch.distributed.algorithms.Joinable

这里定义了一个可加入类的抽象基类。

一个可加入类(继承自 Joinable)需要实现以下内容:

  • join_hook() 方法,返回一个 JoinHook 实例
  • join_device() 方法,返回设备信息
  • join_process_group() 方法,返回进程组信息

ABSTRACT PROPERTY join_device:  device 

返回用于执行由 join 上下文管理器所需的集体通信的设备。


ABSTRACT  join_hook(**kwargs)

为给定的 Joinable 返回一个 JoinHook 实例。

参数

  • kwargs (dict) - 包含运行时修改 join hook 行为的关键字参数字典;所有共享相同 join 上下文管理器的 Joinable 实例都会收到相同的 kwargs 值。

返回类型:JoinHook

ABSTRACT PROPERTY join_process_group:  Any

返回连接上下文管理器本身所需的集体通信的进程组。


class torch.distributed.algorithms.JoinHook

这里定义了一个连接钩子(join hook),它在连接上下文管理器中提供了两个入口点:

入口点包括:
1、主钩子(main hook):当存在未连接的进程时会被重复调用
2、后置钩子(post-hook):当所有进程都完成连接后会被调用一次

要为通用连接上下文管理器实现连接钩子,需要定义一个继承自JoinHook的类,并根据需要重写main_hook()post_hook()方法。


main_hook()

在训练迭代中存在未加入的进程时调用此钩子,以跟踪集体通信。

训练迭代指的是:一次前向传播、反向传播和优化器步骤的过程。


post_hook(is_last_joiner)

在所有进程都加入后调用钩子。

该钩子会接收一个额外的 bool 类型参数 is_last_joiner,用于指示当前 rank 是否属于最后一批加入的进程。

参数

  • is_last_joiner ([bool]) – 当 rank 属于最后一批加入的进程时为 True;否则为 False

Torch Distributed Elastic

为分布式 PyTorch 提供容错与弹性能力。


快速开始

使用指南

  • 快速入门
  • 训练脚本
  • 示例

文档

API

  • torchrun (弹性启动)
  • 弹性代理
  • 多进程处理
  • 错误传播
  • 集合点
  • 过期计时器
  • 指标监控
  • 事件处理
  • 子进程处理
  • 控制平面

高级功能

  • 自定义配置

插件

  • TorchElastic Kubernetes

2025-05-10(六)

相关文章:

  • SpringCloud之Eureka基础认识-服务注册中心
  • 一、数据仓库基石:核心理论、分层艺术与 ETL/ELT 之辨
  • 第十七次博客打卡
  • MySQL 从入门到精通(六):视图全面详解 —— 虚拟表的灵活运用
  • vue开发用户注册功能
  • JVM 数据区域
  • 微服务6大拆分原则
  • Linux 下 Java 部署环境搭建与项目部署详细步骤
  • PyTorch 线性回归模型构建与神经网络基础要点解析
  • 【金仓数据库征文】学校AI数字人:从Sql Server到KingbaseES的数据库转型之路
  • 十六、统一建模语言 UML
  • cdn 是什么?
  • AIGC时代大模型幻觉问题深度治理:技术体系、工程实践与未来演进
  • LSP里氏替换原则
  • 全息美AISEO引领未来智能营销新趋势
  • 关键点检测--使用YOLOv8对Leeds Sports Pose(LSP)关键点检测
  • Kubernetes生产实战(十六):集群安全加固全攻略
  • 协议路由与路由协议
  • 数据库索引详解:原理 · 类型 · 使用 · 优化
  • 流式数据(Streaming Data)和非流式数据(Batch Data)区别、使用场景、优化-来自前端的浅解
  • 武汉旅游体育集团有限公司原党委书记、董事长董志向被查
  • 第19届威尼斯建筑双年展开幕,中国案例呈现“容·智慧”
  • 印控克什米尔地区再次传出爆炸声
  • 海航回应“男团粉丝为追星堵住机舱通道”:已紧急阻止
  • 外交部发言人就印巴局势升级答记者问
  • 印度一战机在巴基斯坦旁遮普省被击落,飞行员被俘