当前位置: 首页 > news >正文

PyTorch API 4

文章目录

  • torch.distributed.fsdp.fully_shard
    • PyTorch FSDP2 (`fully_shard`)
  • Tensor Parallelism - torch.distributed.tensor.parallel
  • 分布式优化器
  • 流水线并行
    • 为什么需要流水线并行?
    • 什么是 `torch.distributed.pipelining`?
    • 第一步:构建 `PipelineStage`
    • 步骤2:使用`PipelineSchedule`执行
    • 模型分割方案
      • 方案一:手动拆分模型
      • 选项2:自动拆分模型
    • Hugging Face 示例
    • 技术深度解析
      • `pipeline` API 如何分割模型?
    • 实现自定义调度策略
    • 日志记录
    • API 参考
      • 模型拆分 API
      • 微批次工具集
      • 流水线阶段
      • 流水线调度
  • 分布式检查点 - torch.distributed.checkpoint
    • 附加资源:
    • StorageReader 方法的实现
  • 概率分布 - torch.distributions
    • 评分函数
    • 路径导数
    • 分发
    • 指数族分布
    • 伯努利
    • Beta
    • Binomial
    • Categorical
    • Cauchy
    • Chi2
    • 连续伯努利分布
    • 狄利克雷
    • 指数函数
    • 费希尔-斯涅克分布
    • Gamma
    • 几何
    • gumbel
    • HalfCauchy
    • HalfNormal
    • 独立
    • InverseGamma
    • Kumaraswamy
    • LKJCholesky
    • Laplace
    • 对数正态分布
    • LowRankMultivariateNormal
    • 混合相同族分布
    • Multinomial
    • 多元正态分布
    • NegativeBinomial
    • 常规
    • OneHotCategorical
    • Pareto
    • Poisson
    • 松弛伯努利分布
    • LogitRelaxedBernoulli
    • RelaxedOneHotCategorical
    • StudentT 分布
    • TransformedDistribution
    • 统一性
    • 冯·米塞斯
    • weibull
    • wishart
    • KL Divergence
    • Transforms
    • Constraints
    • Constraint Registry


torch.distributed.fsdp.fully_shard


PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 提供了一种完全分片数据并行(FSDP)实现,旨在实现高性能的即时执行模式,同时采用逐参数分片以提升易用性。

  • 如果您是 FSDP 的新用户,我们建议从 FSDP2 开始使用,因其具有更好的易用性。
  • 如果您当前正在使用 FSDP1,请评估以下差异以决定是否应切换到 FSDP2:

与 PyTorch FSDP1 (FullyShardedDataParallel) 相比:

  • FSDP2 使用基于 DTensor 的维度 0 逐参数分片,相比 FSDP1 的扁平参数分片提供了更简单的分片表示,同时保持了相似的吞吐性能。具体来说,FSDP2 在数据并行工作节点间对每个参数沿维度 0 进行分块(使用 torch.chunk(dim=0)),而 FSDP1 会将一组张量展平、拼接并一起分块,这使得理解每个工作节点上的数据以及重新分片到不同并行模式变得复杂。逐参数分片提供了更直观的用户体验,放宽了对冻结参数的限制,并允许无通信(分片)的状态字典,而在 FSDP1 中则需要全收集操作。
  • FSDP2 采用不同的内存管理方法来处理多流使用场景,避免了 torch.Tensor.record_stream。这确保了确定性和预期的内存使用,且不需要像 FSDP1 的 limit_all_gathers=True 那样阻塞 CPU。
  • FSDP2 提供了手动控制预取和集体调度的 API,允许高级用户进行更多自定义。详情请参阅下文 FSDPModule 的方法。
  • FSDP2 简化了部分 API 接口:例如,FSDP2 不直接支持完整状态字典。用户可以使用 DTensor API(如 DTensor.full_tensor())或更高级的 API(如 PyTorch 分布式检查点 的分布式状态字典 API)自行将包含 DTensor 的分片状态字典重新分片为完整状态字典。此外,一些其他参数已被移除;详情请参阅此处。

如果您是首次使用 FSDP,或上述任何一点符合您的使用场景,我们建议您考虑使用 FSDP2。

有关系统设计和实现的详细信息,请参阅此 RFC。


注意:torch.distributed.fsdp.fully_shard 目前处于原型阶段,正在开发中。核心 API 可能不会更改,但我们可能会根据需要进行一些 API 调整。

前端 API 是 fully_shard,可以在 module 上调用:

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)

module应用全分片数据并行(FSDP),其中FSDP将模块参数、梯度和优化器状态分片到数据并行工作节点上,以通信开销为代价节省内存。

初始化时,FSDP根据mesh指定的数据并行工作节点对模块参数进行分片。在前向计算前,FSDP通过全收集操作跨数据并行工作节点获取完整参数用于计算。若reshard_after_forwardTrue,则FSDP在前向计算后释放完整参数,并在反向计算前重新全收集。梯度计算完成后,FSDP释放完整参数并通过规约分散操作分发未分片梯度。

本实现使用DTensor表示分片参数(沿0维分片),而完整参数保持与原始模块参数相同类型(如原为torch.Tensor则仍为torch.Tensor)。模块的前向预钩子负责参数全收集,前向钩子负责释放参数(如需要)。类似的反向钩子处理参数收集与梯度分发。

为提高通信效率,本实现将多个张量分组进行集合操作。对module调用fully_shard()会构建包含module.parameters()中参数的分组(子模块已分组参数除外),因此应在模型上自底向上调用fully_shard()。每组参数通过单次集合操作完成全收集和梯度规约分散。分层分组(“逐层”)可实现内存峰值优化和通信/计算重叠。通常不应仅在顶层模块调用fully_shard()

参数说明

  • module (Union[nn.Module, List[nn.Module]) – 待分片的模块或模块列表,这些模块将被分组进行通信
  • mesh (Optional[[DeviceMesh](distributed.html#torch.distributed.device_mesh.DeviceMesh "torch.distributed.device_mesh.DeviceMesh")]) – 数据并行网格定义分片方式和设备:
    • 一维网格:参数完全分片(FSDP),采用(Shard(0),)布局
    • 二维网格:参数沿第1维分片且沿第0维复制(HSDP),采用(Replicate(), Shard(0))布局
    • 网格设备类型决定通信设备类型(如CUDA类设备使用当前设备)
  • reshard_after_forward (Union[[bool],* int ]) – 控制前向计算后的参数行为,平衡内存与通信:
    • True:前向后重新分片参数,反向时重新全收集
    • False:前向后保留完整参数,避免反向全收集
    • 整数值:指定前向后的分片规模(应为网格分片维度的非平凡除数,如节点内设备数torch.cuda.device_count()),以较小通信规模换取较高内存使用
    • 根FSDP状态默认设为False(因其参数通常需立即用于反向计算)
    • 前向后模块注册参数取决于该设置:分片参数(True时)、完整参数(False时)或缩网格分片参数(整数值时)。若需在前反向间修改参数,必须注册分片参数(对False或整数值可通过reshard()手动分片)
  • shard_placement_fn (Optional[Callable[[nn.Parameter],* Optional[Shard ]]]) – 自定义参数分片布局(如返回Shard(1)则沿1维分片)。当前非零维分片要求张量维度大小可被分片网格大小整除
  • mp_policy ( MixedPrecisionPolicy ) – 混合精度策略,控制该模块的参数/规约精度。详见MixedPrecisionPolicy
  • offload_policy (OffloadPolicy) – 卸载策略,控制参数/梯度/优化器状态卸载。详见OffloadPolicy及其子类
  • ignored_params (Optional[set[nn.Parameter]]) – 不需要FSDP分片的参数集合

返回值:返回应用FSDP后的模块(原地修改),类型为FSDPModule

调用fully_shard(module)会动态创建继承原模块类型和FSDPModule的新类。例如对linear: nn.Linear调用fully_shard(linear)会生成FSDPLinear类并转换模块类型。该方法不改变模块结构和参数全限定名,FSDPModule类提供特定于FSDP的方法支持。


class torch.distributed.fsdp.FSDPModule(*args, **kwargs) reshard()

对模块参数进行重新分片,如果未分片参数已分配则释放它们,并将分片后的参数注册到模块中。该方法不会递归执行。


set_all_reduce_hook(hook, *, stream=None)

参数

  • hook (Callable[[torch.Tensor], None]) – 用户自定义的all-reduce钩子函数,预期签名为hook(reduce_output: torch.Tensor) -> None

其中reduce_output表示:

  • 若仅使用FSDP则为reduce-scatter操作的输出
  • 若使用原生HSDP则为all-reduce操作的输出
  • stream (Optional[torch.cuda.Stream]) – 运行all-reduce钩子的CUDA流。注意:
    • 仅在不使用原生HSDP时需要设置此参数
    • 若使用原生HSDP,钩子将在HSDP内部定义的all-reduce流中自动执行

set_is_last_backward(is_last_backward)

设置下一次反向传播是否为最后一次。在最后一次反向传播时,FSDP会等待待处理的梯度归约操作完成,并清除用于反向预取的内部数据结构。这一特性对于微批次训练特别有用。


set_modules_to_backward_prefetch(modules)

设置当前FSDP模块在反向传播时应显式预取全聚集操作的FSDP模块。这会覆盖默认的反向预取实现(默认实现基于反向后序顺序预取下一个FSDP模块)。

传入包含前一个FSDP模块的单例列表,可获得与默认重叠行为相同的全聚集操作重叠效果。

若需更激进的重叠效果(会占用更多预留内存),则必须传入至少包含两个模块的列表。

参数:
modules (List[FSDPModule]) – 需要预取的FSDP模块列表。


set_modules_to_forward_prefetch(modules)

设置此FSDP模块在正向传播中应显式预取全收集操作的FSDP模块。预取操作会在本模块的全收集复制输出之后执行。

如果传入仅包含下一个FSDP模块的单例列表,将获得与默认重叠行为相同的全收集重叠效果,区别在于预取的全收集操作会从CPU端更早发起。要实现更激进的重叠效果(将占用更多预留内存),必须传入至少包含两个模块的列表。

参数

modules (List[FSDPModule]) – 需要预取的FSDP模块列表。


set_post_optim_event(event)

为根FSDP模块设置一个优化器步骤后事件,用于等待所有聚集流就绪。

默认情况下,根FSDP模块会在当前流上等待所有聚集流,以确保优化器步骤在开始全聚集前已完成。但如果优化器步骤后存在无关计算,这种方式可能会引入虚假依赖。该API允许用户提供自定义事件进行等待。当根模块完成事件等待后,该事件会被丢弃,因此每次迭代都应调用本API传入新事件。

参数

event (torch.Event) - 记录在优化器步骤后、用于等待所有聚集流的事件对象。


set_reduce_scatter_divide_factor(factor)

为reduce-scatter操作设置自定义的除法因子。这将通过NCCL的PreMulSum功能实现一个自定义的reduce操作,允许在归约前先乘以该因子。

参数

factor (float) – 自定义除法因子。


set_requires_all_reduce(requires_all_reduce, *, recurse=True)

设置该模块是否应执行梯度全归约操作。这可用于实现仅使用reduce-scatter而不进行全归约的梯度累积方案,适用于HSDP场景。


set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)

设置模块是否应同步梯度。该功能可用于实现无需通信的梯度累积。对于HSDP,这将同时控制reduce-scatter和all-reduce操作。其功能等同于FSDP1中的no_sync。

参数说明

  • requires_gradient_sync ([bool]) – 控制是否对模块参数执行梯度归约操作。
  • recurse ([bool]) – 控制设置范围:仅作用于当前模块,还是递归作用于所有FSDP子模块。

set_reshard_after_backward(reshard_after_backward, *, recurse=True)

设置模块是否应在反向传播后重新分片参数。这在梯度累积期间可用于以更高内存为代价换取减少通信,因为在下一次前向传播前无需重新全收集未分片的参数。

参数

  • reshard_after_backward ([bool]) – 是否在反向传播后重新分片参数。
  • recurse ([bool]) – 是为所有FSDP子模块设置还是仅针对传入的模块设置。

set_unshard_in_backward(unshard_in_backward)

设置是否需要在反向传播时解除FSDP模块参数的共享状态。这一功能适用于专家级场景,当用户确定该FSDP模块参数组中的所有参数都不参与反向计算时(例如嵌入层),便可使用此设置。


unshard(async_op=False)

通过分配内存并全收集(all-gather)参数来解除模块参数的分片状态。此方法不会递归执行。解除分片操作遵循MixedPrecisionPolicy,因此如果设置了param_dtype,将按照该类型进行全收集。

参数

  • async_op ([bool]) - 若为True,则返回一个包含wait()方法的UnshardHandle对象用于等待解除分片操作;若为False,则返回None并在函数内部等待操作完成。

返回类型:Optional[UnshardHandle]

注意:当async_op=True时,FSDP会在模块的前向传播前自动等待待处理的解除分片操作。用户只需在需要前向传播前显式调用wait()方法即可。


class torch.distributed.fsdp.UnshardHandle 

一个用于等待 FSDPModule.unshard() 操作完成的句柄。


wait()

等待取消分片操作完成。这确保了当前流可以使用已取消分片的参数,这些参数现在已注册到模块中。


torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)

module 上注册一个方法,使其被视为 FSDP 的前向传播方法。

FSDP 会在前向传播前执行参数的全收集(all-gather),并可选地在前向传播后释放参数(取决于 reshard_after_forward 的设置)。默认情况下,FSDP 仅对 nn.Module.forward() 执行此操作。此函数通过钩子机制,使用户指定的方法分别在执行前后运行前向/后向传播的预处理/后处理逻辑。如果 module 不是 FSDPModule 实例,则该操作无效。

参数

  • module (nn.Module) – 需要注册前向传播方法的模块。
  • method_name (str) – 前向传播方法的名称。

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True) 

该配置用于设置FSDP的混合精度。与autocast不同,这是在模块级别而非操作级别应用混合精度,意味着会保存低精度激活值用于反向传播,而高精度到低精度的转换仅发生在模块边界处。

FSDP与模块级混合精度配合良好,因为它始终在内存中保存高精度分片参数。换句话说,FSDP不需要额外内存来保存高精度参数副本用于优化器步骤。

变量说明:

  • param_dtype (Optional[torch.dtype]) - 指定未分片参数的数据类型,即前向/反向计算和参数全收集时使用的数据类型。若为None,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型的已分片参数。(默认值:None
  • reduce_dtype (Optional[torch.dtype]) - 指定梯度规约(即reduce-scatter或all-reduce)时使用的数据类型。若为Noneparam_dtype不为None,则规约使用计算数据类型。该参数可用于在计算时使用低精度,同时保持梯度规约为全精度。若通过set_requires_gradient_sync()禁用梯度规约,FSDP将使用reduce_dtype累积梯度。(默认值:None
  • output_dtype (Optional[torch.dtype]) - 指定浮点前向输出结果的转换数据类型。可用于实现不同模块采用不同混合精度策略的场景。(默认值:None
  • cast_forward_inputs ([bool]) - 指定FSDP是否应将前向传播的浮点输入张量转换为param_dtype类型。

class torch.distributed.fsdp.OffloadPolicy 

这个基类表示不进行卸载的策略,仅用作 offload_policy 参数的默认值。


class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True) 

该卸载策略将参数、梯度和优化器状态卸载到CPU。分片参数在all-gather操作前会从主机内存复制到设备内存。根据reshard_after_forward的设置,all-gather后的参数会被释放。

在反向传播过程中,分片梯度会从设备内存复制到主机内存,优化器步骤则在CPU上使用CPU优化器状态运行。

变量说明:

  • pin_memory ([bool]) – 是否固定分片参数和梯度的内存。固定内存可以实现更高效率的主机到设备/设备到主机的内存拷贝,并使拷贝操作与计算重叠。但固定内存无法被其他进程使用。若CPU内存不足,请将此参数设为False。(默认值:True


Tensor Parallelism - torch.distributed.tensor.parallel

Tensor Parallelism(张量并行,简称TP)构建在PyTorch DistributedTensor(DTensor)之上,提供多种并行风格:列并行(Colwise)、行并行(Rowwise)以及序列并行(Sequence Parallelism)。


警告:Tensor Parallelism API目前处于实验阶段,后续可能发生变更。

使用Tensor Parallelism并行化nn.Module的入口点是:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)

在PyTorch中通过基于用户指定方案并行化模块或子模块来应用张量并行。

我们根据parallelize_plan对模块或子模块进行并行化。该计划包含:

ParallelStyle,用于指示用户希望如何并行化模块或子模块。

用户还可以为每个模块的完全限定名称(FQN)指定不同的并行风格。

注意:parallelize_module仅接受一维DeviceMesh。如果使用二维或N维DeviceMesh,需先将DeviceMesh切片为一维子DeviceMesh再传入此API(例如device_mesh["tp"])。

参数

  • module (nn.Module) – 待并行化的模块。
  • device_mesh (DeviceMesh, 可选) – 描述DTensor设备网格拓扑的对象。若未指定,调用必须在DeviceMesh上下文中进行。
  • parallelize_plan (Union [ParallelStyle, Dict[str, ParallelStyle]], 可选) – 模块并行化方案。可以是包含张量并行输入/输出准备的ParallelStyle对象,或是模块FQN与其对应ParallelStyle对象的字典。若未指定,当前调用不会执行任何操作。

关键字参数

  • src_data_rank ( int , 可选) – 逻辑/全局张量的源数据秩,由distribute_tensor()用于将分片/副本分发到其他秩。默认使用每个DeviceMesh维度上的group_rank=0作为源数据以保持单设备语义。若显式传入Noneparallelize_module()将直接使用本地数据,而非通过分发/广播保持单设备语义。默认值:0

返回值:并行化后的nn.Module对象。

返回类型:Module


示例:

>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> >
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>> >

注意:对于像Attention、MLP层这样的复杂模块架构,我们建议将不同的ParallelStyle组合使用(例如ColwiseParallelRowwiseParallel),并通过parallelize_plan传递,以实现所需的分片计算。

Tensor Parallelism支持以下并行风格:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)

以列式方式对兼容的 nn.Module 进行分区。当前支持 nn.Linearnn.Embedding

用户可将其与 RowwiseParallel 组合使用,以实现更复杂模块的分片(例如 MLP、Attention)。

关键字参数

  • input_layouts (Placement, 可选) – 输入张量在 nn.Module 中的 DTensor 布局,用于将输入张量标注为 DTensor。若未指定,则默认输入张量为副本形式。
  • output_layouts (Placement, 可选) – 输出张量在 nn.Module 中的 DTensor 布局,用于确保模块输出符合用户预期的布局。若未指定,输出张量将在最后一维分片。
  • use_local_output (bool, 可选) – 是否使用本地 torch.Tensor 而非 DTensor 作为模块输出,默认值:True。

返回

一个表示 nn.Module 列式分片的 ParallelStyle 对象。


示例

>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>> >
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

注意:默认情况下,如果未指定 output_layoutsColwiseParallel 的输出会在最后一个维度上进行分片。如果存在需要特定张量形状的运算符(例如在配对的 RowwiseParallel 之前),请记住,若输出被分片,则可能需要根据分片后的尺寸调整该运算符。


class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)

以行方式对兼容的 nn.Module 进行分区。当前支持 nn.Linearnn.Embedding

用户可结合 ColwiseParallel 来实现更复杂模块的分片(例如 MLP、Attention)。

关键字参数

  • input_layouts (Placement, 可选) – 用于标注输入张量成为 DTensor 的布局参数。若未指定,则默认输入张量在最后一个维度分片。
  • output_layouts (Placement, 可选) – 确保模块输出符合用户预期布局的参数。若未指定,输出张量将被复制为全副本。
  • use_local_output (bool, 可选) – 是否使用本地 torch.Tensor 而非 DTensor 作为模块输出,默认值:True。

返回值
返回代表 nn.Module 行分片的 ParallelStyle 对象。


示例

>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>> >
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>...

class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)

SequenceParallel(序列并行)会复制兼容的nn.Module参数,并在序列维度上对分片输入执行分片计算。当前支持nn.LayerNormnn.Dropout以及RMSNorm的Python实现。

该模式实现了论文《减少大型Transformer模型中的激活重计算》中描述的操作。

若传入该nn.Module的输入是torch.Tensor,则假定输入已在序列维度分片,并将其转换为序列维度分片的DTensor。若传入的输入已是DTensor但未在序列维度分片,则会重新分配输入使其在序列维度分片。

nn.Module的输出将在序列维度分片。

关键字参数

  • sequence_dim (int, 可选) – 用于指定输入张量的序列维度,该参数会将输入张量标注为序列维度分片的DTensor,默认值:1
  • use_local_output (bool, 可选) – 是否对模块输出使用本地torch.Tensor而非DTensor,默认值:False

返回
一个代表nn.Module序列并行化的ParallelStyle对象。


示例

>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>> >
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>...

注意:SequenceParallel 风格假设 nn.Module 中的权重采用全1初始化(例如 nn.LayerNormRMSNorm,这些模块默认采用全1初始化)。如果这些模块的权重采用自定义初始化方式,则需要在并行化前后广播权重以确保权重被正确复制。

若只需为 nn.Module 的输入输出配置 DTensor 布局并执行必要的布局重分布,而无需将模块参数分发为 DTensor,在调用 parallelize_module 时可在 parallelize_plan 中使用以下 ParallelStyle

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)

配置 nn.Module 的输入参数,根据 input_layouts 在运行时将 nn.Module 的输入张量转换为 DTensor,并按照 desired_input_layouts 执行布局重分布。

关键字参数

  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) - 用于指定 nn.Module 输入张量的 DTensor 布局,该参数用于将输入张量转换为 DTensor。如果某些输入不是 torch.Tensor 或无需转换为 DTensor,需要用 None 作为占位符。默认值:None。
  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) - 用于指定 nn.Module 输入张量的期望 DTensor 布局,该参数确保 nn.Module 的输入具有期望的 DTensor 布局。此参数需要与 input_layouts 长度相同。默认值:None。
  • input_kwarg_layouts (Dict[str, Placement]) - 用于指定 nn.Module 输入关键字参数的 DTensor 布局,该参数用于将输入关键字参数张量转换为 DTensor。默认值:None。
  • desired_input_kwarg_layouts – (Dict[str, Placement]) - 用于指定 nn.Module 输入关键字参数的期望 DTensor 布局,该参数确保 nn.Module 的输入具有期望的 DTensor 布局。默认值:None。
  • use_local_output ([bool], 可选) - 是否对模块输入使用本地 torch.Tensor 而非 DTensor。默认值:False。

返回值:返回一个 ParallelStyle 对象,用于准备 nn.Module 输入的分片布局。


示例:

>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh, >>    parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, 
...), >>            desired_input_layouts=(Replicate(), None, None, 
...)
>>>         ), >>    }
>>> )

class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)

配置nn.Module的输出,在运行时根据output_layoutsnn.Module的输出张量转换为DTensor,并根据desired_output_layouts执行布局重分布。

关键字参数

  • output_layouts (Union[Placement , Tuple[Placement ]]) - 用于指定nn.Module输出张量的DTensor布局,当输出为torch.Tensor时将其转换为DTensor。如果某些输出不是torch.Tensor或无需转换,需要用None作为占位符。
  • desired_output_layouts (Union[Placement , Tuple[Placement ]]) - 指定nn.Module输出张量的目标DTensor布局,用于确保模块输出具有预期的DTensor布局。
  • use_local_output ([bool], 可选) - 是否对模块输出使用本地torch.Tensor而非DTensor,默认值为True。

返回值:返回一个ParallelStyle对象,用于设置nn.Module输出张量的分片布局。


示例:

>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh, >>    parallelize_plan = PrepareModuleOutput(
>>>         output_layouts=Replicate(), >>        desired_output_layouts=Shard(0)
>>>     )
>>> )

注意:当使用 Shard(dim) 作为上述 ParallelStyle 的输入/输出布局时,我们假设输入/输出激活张量在 TP 操作的 DeviceMesh 上沿张量维度 dim 均匀分片。例如,由于 RowwiseParallel 接受在最后一个维度分片的输入,它假设输入张量已在最后一个维度上均匀分片。对于非均匀分片的激活张量,用户可以直接将 DTensor 传入分区模块,并通过设置 use_local_output=False 使每个 ParallelStyle 处理后返回 DTensor,此时 DTensor 会记录非均匀分片信息。

对于 Transformer 这类模型,我们建议用户在 parallelize_plan 中同时使用 ColwiseParallelRowwiseParallel,以实现整个模型(包括注意力层和 MLP)的预期分片效果。

并行化的交叉熵损失计算(损失并行)可通过以下上下文管理器支持:

torch.distributed.tensor.parallel.loss_parallel()

一个支持损失并行计算的上下文管理器,当输入在类别维度上分片时,可以执行高效的并行化损失计算。目前仅支持交叉熵损失。

在该上下文管理器内,可以像往常一样使用 cross_entropy()CrossEntropyLoss,但需满足以下输入参数假设。

对应的 backward() 调用(如有)也需要在该上下文管理器下进行。

参数

  • input (DTensor) – 输入logits。假设在类别维度上分片。
  • target (Union [torch.Tensor, DTensor]) – 必须是真实类别索引(当前不支持类别概率)。假设在 DeviceMesh 上复制。
  • weight (Union [torch.Tensor, DTensor], 可选) – 如果提供,假设在 DeviceMesh 上复制。
  • label_smoothing – 当前不支持。

返回

一个复制的 DTensor


示例

这里手动创建了一个分片的DTensor来展示用法。实际应用中,它通常是TP模块的输出。


>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

警告:loss_parallel API 目前处于实验阶段,后续可能会发生变化。



分布式优化器


警告:当前不支持在使用CUDA张量时使用分布式优化器

torch.distributed.optim提供了DistributedOptimizer,它接收一个远程参数列表(RRef)并在参数所在的worker节点上本地运行优化器。该分布式优化器可以使用任何本地优化器基类来在每个worker上应用梯度。


class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)

分布式优化器(DistributedOptimizer)接收分布在各个工作节点上的参数的远程引用,并为每个参数在本地应用指定的优化器。

该类通过 get_gradients() 方法来获取特定参数的梯度。

step() 的并发调用(无论来自相同或不同客户端)将在每个工作节点上串行执行——因为每个工作节点的优化器一次只能处理一组梯度。但无法保证完整的正向-反向-优化器序列会为一个客户端连续执行。这意味着应用的梯度可能不对应于给定工作节点上执行的最新正向传递。此外,不同工作节点之间也没有保证的执行顺序。

分布式优化器默认启用 TorchScript 来创建本地优化器,这样在多线程训练(如分布式模型并行)时,优化器更新不会被 Python 全局解释器锁(GIL)阻塞。目前大多数优化器都支持此功能。您也可以参考 PyTorch 教程中的这个示例来为自己的自定义优化器启用 TorchScript 支持。

参数

  • optimizer_class ([optim.Optimizer](optim.html#torch.optim.Optimizer "torch.optim.Optimizer")) – 要在每个工作节点上实例化的优化器类。
  • params_rref (list[RRef]) – 要优化的本地或远程参数的 RRef 列表。
  • args – 传递给每个工作节点上优化器构造函数的参数。
  • kwargs – 传递给每个工作节点上优化器构造函数的参数。

示例:

>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>> >
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>> >
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>> >
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD, >>     [rref1, rref2], >>     lr=0.05, >>  )
>>>   dist_optim.step(context_id)

step(context_id)

执行单次优化步骤。

该方法会在每个包含待优化参数的 worker 上调用 torch.optim.Optimizer.step(),并阻塞直至所有 worker 返回。提供的 context_id 将用于检索对应的 context,该上下文包含应应用于参数的梯度。

参数

  • context_id - 用于运行优化器步骤的自动求导上下文 ID。

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)

封装一个任意的 torch.optim.Optimizer 并运行 post-local SGD。该优化器在每一步都运行本地优化器。

在预热阶段结束后,它会在应用本地优化器后定期对参数进行平均。

参数

  • optim ([Optimizer](optim.html#torch.optim.Optimizer "torch.optim.optimizer.Optimizer")) – 本地优化器。
  • averager (ModelAverager) – 用于运行 post-localSGD 算法的模型平均器实例。

示例

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>   PostLocalSGDState, >>  post_localSGD_hook, >>)
>>> >
>>> model = nn.parallel.DistributedDataParallel(
>>>    module, device_ids=[rank], output_device=rank
>>> )
>>> >
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>> >
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as >># ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim, >>    averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>> >
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), >># and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>    opt.zero_grad()
>>>    loss = loss_fn(output, labels)
>>>    loss.backward()
>>>    opt.step()

load_state_dict(state_dict)

这与 torch.optim.Optimizerload_state_dict() 方法功能相同,但还会将模型平均器的步长值恢复为提供的 state_dict 中保存的值。

如果 state_dict 中没有 "step" 条目,系统会发出警告并将模型平均器的步长初始化为 0。


state_dict()

这与 torch.optim.Optimizerstate_dict() 功能相同,但额外增加了一个条目用于记录模型平均器的步骤到检查点,以确保重新加载时不会再次引发不必要的预热过程。


step()

执行单次优化步骤(参数更新)。


class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)

包装一个任意的 optim.Optimizer 并在组内各 rank 之间分片其状态。

分片方式遵循 ZeRO 论文描述。

每个 rank 的本地优化器实例仅负责更新约 1 / world_size 的参数,因此只需维护 1 / world_size 的优化器状态。本地参数更新完成后,每个 rank 会将其参数广播给所有其他节点,以保持所有模型副本的状态一致。

ZeroRedundancyOptimizer 可与 torch.nn.parallel.DistributedDataParallel 结合使用,以降低单 rank 的峰值内存消耗。

ZeroRedundancyOptimizer 使用排序贪心算法在每个 rank 上打包若干参数。每个参数仅属于单一 rank,不会被分割到多个 rank。这种划分是任意的,可能与参数注册顺序或使用顺序不一致。

参数

  • params (Iterable) - 包含所有待分片参数的 torch.Tensordict 的可迭代对象

关键字参数

  • optimizer_class (torch.nn.Optimizer) - 本地优化器的类

  • process_group (ProcessGroup, 可选) - torch.distributedProcessGroup(默认使用 torch.distributed.init_process_group() 初始化的 dist.group.WORLD

  • parameters_as_bucket_view ([bool], 可选) - 若为 True,参数会被打包到桶中以加速通信,且 param.data 字段指向桶视图的不同偏移量;若为 False,则单独通信每个参数,且每个 params.data 保持不变(默认:False

  • overlap_with_ddp ([bool], 可选) - 若为 Truestep() 将与 DistributedDataParallel 的梯度同步过程重叠执行,这要求:
    1、optimizer_class 必须是函数式优化器或具有等效功能
    2、需注册来自 ddp_zero_hook.py 的 DDP 通信钩子
    参数会打包为与 DistributedDataParallel 匹配的桶,此时 parameters_as_bucket_view 参数将被忽略。

    若为 Falsestep() 将在反向传播后独立执行(默认行为)(默认:False

  • **defaults - 其他尾部参数,将传递给本地优化器


示例

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential([nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(), >>    optimizer_class=torch.optim.Adam, >>    lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告:目前 ZeroRedundancyOptimizer 要求所有传入参数必须是相同密集类型。

警告:如果设置 overlap_with_ddp=True,请注意以下情况:根据当前 DistributedDataParallelZeroRedundancyOptimizer 重叠的实现方式,前两到三次训练迭代不会执行优化器参数更新(具体次数取决于 static_graph=Falsestatic_graph=True)。这是因为需要获取 DistributedDataParallel 使用的梯度分桶策略信息——当 static_graph=False 时该信息在第二次前向传播后才会确定,而 static_graph=True 时则需等到第三次前向传播。解决方法之一是在训练数据前添加虚拟输入。

警告:ZeroRedundancyOptimizer 仍处于实验阶段,功能可能发生变化。


add_param_group(param_group)

Optimizerparam_groups 添加一个参数组。

在微调预训练网络时,这个方法非常有用——随着训练进行,原本冻结的层可以变为可训练状态,并添加到 Optimizer 中。

参数说明

  • param_group ( dict ) - 指定待优化的参数及该组特有的优化选项。

警告说明
此方法会更新所有分区的参数分片,但必须在所有计算节点上调用。若仅部分节点调用该方法,会导致训练挂起,因为通信原语的调用依赖于托管参数,且要求所有节点必须基于同一组参数参与计算。


consolidate_state_dict(to=0)

将各 rank 的 state_dict 列表(每个 rank 一个)合并到目标 rank 上。

参数

  • to (int) – 接收优化器状态的 rank(默认值:0)。

抛出异常
RuntimeError – 若 overlap_with_ddp=True 且此方法在 ZeroRedundancyOptimizer 实例完全初始化前被调用(完全初始化需等待 DistributedDataParallel 梯度桶重建完成)。

警告:必须在所有 rank 上调用此方法。


property join_device:  device 

返回默认设备。


join_hook(**kwargs)

返回 ZeRO 连接钩子。

该钩子通过遮蔽优化器步骤中的集体通信操作,实现在非均匀输入数据上的训练。

调用此钩子前必须正确设置梯度。

参数

  • kwargs ( dict ) – 一个包含运行时修改连接钩子行为的关键字参数的字典;所有共享同一连接上下文管理器的 Joinable 实例都会收到相同的 kwargs 值。

此钩子不支持任何关键字参数,即 kwargs 未被使用。


property join_process_group:  Any

返回进程组。


load_state_dict(state_dict)

从输入的 state_dict 中加载与指定 rank 相关的状态,并根据需要更新本地优化器。

参数

  • state_dict ( dict ) – 优化器状态;应为调用 state_dict() 返回的对象。

抛出异常

RuntimeError – 如果 overlap_with_ddp=True 且此方法在 ZeroRedundancyOptimizer 实例完全初始化之前被调用(完全初始化发生在 DistributedDataParallel 梯度桶重建完成之后)。


state_dict()

返回当前节点已知的最后一个全局优化器状态。

可能引发的异常

RuntimeError —— 当满足以下任一条件时抛出:
1、设置overlap_with_ddp=True时,在ZeroRedundancyOptimizer实例完全初始化前调用本方法(初始化完成标志是DistributedDataParallel梯度桶重建完成);
2、调用本方法前未先调用consolidate_state_dict()方法。

返回类型:dict[str , Any ]


step(closure=None, **kwargs)

执行单次优化器步骤并同步所有进程间的参数。

参数

  • closure (Callable) – 用于重新评估模型并返回损失值的闭包函数;大多数优化器可省略此参数。

返回值:取决于底层本地优化器的可选损失值。

返回类型:Optional[float]

注意:所有额外参数都将原样传递给基础优化器。



流水线并行


注意:torch.distributed.pipelining 目前处于 alpha 阶段且正在开发中。API 可能会发生变化。该功能是从 PiPPy 项目迁移而来。


为什么需要流水线并行?

流水线并行是深度学习中基础的并行方式之一。它允许将模型执行过程进行划分,使得多个微批次能够同时执行模型代码的不同部分。流水线并行在以下场景中尤为有效:

  • 大规模训练
  • 带宽受限的集群
  • 大模型推理

这些场景的共同特点是:单个设备的计算量无法掩盖传统并行方式(如FSDP的权重全收集操作)带来的通信开销。


什么是 torch.distributed.pipelining

虽然流水线并行在扩展性方面前景广阔,但其实现往往颇具挑战性,因为它不仅需要对模型权重进行划分,还需要拆分模型执行过程。执行过程的划分通常需要对模型代码进行侵入式修改。另一重复杂性来源于分布式环境中的微批次调度,同时还需考虑数据流依赖关系

pipelining 包提供了一套自动化工具链,能够自动完成上述操作,从而在通用模型上轻松实现流水线并行。

该工具包包含两个核心组件:拆分前端分布式运行时。拆分前端直接接收原始模型代码,将其分割为多个"模型分区",并捕获数据流关系。分布式运行时则在不同设备上并行执行流水线阶段,处理微批次划分、调度、通信和梯度传播等任务。

总体而言,pipelining 包提供以下核心功能:

  • 基于简单配置的模型代码自动拆分
  • 全面支持多种流水线调度策略(包括GPipe、1F1B、交错式1F1B和循环BFS),并提供自定义调度器开发基础设施
  • 原生支持跨主机流水线并行(这是PP技术最典型的应用场景,适用于低速网络互联环境)
  • 可与PyTorch其他并行技术(如数据并行DDP/FSDP或张量并行)组合使用。TorchTitan项目展示了在Llama模型上实现"3D并行"的应用案例。

第一步:构建 PipelineStage

在使用 PipelineSchedule 之前,我们需要先创建 PipelineStage 对象,这些对象封装了在该阶段运行的模型部分。PipelineStage 负责分配通信缓冲区,并创建发送/接收操作以与对等节点通信。它管理中间缓冲区,例如尚未被消费的前向输出,并提供运行阶段模型反向传播的实用工具。

PipelineStage 需要知道阶段模型的输入和输出形状,以便正确分配通信缓冲区。这些形状必须是静态的,即在运行时,形状不能每一步都变化。如果运行时形状与预期形状不匹配,将抛出 PipeliningShapeError 异常。在与其他并行技术组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage 在运行时知道阶段模块输出的正确形状(和数据类型)。

用户可以直接构造 PipelineStage 实例,方法是传入一个 nn.Module,表示应在该阶段运行的模型部分。这可能需要对原始模型代码进行修改。具体示例请参见选项1:手动拆分模型。

或者,拆分前端可以使用图分区技术自动将模型拆分为一系列 nn.Module。此技术要求模型可以通过 torch.Export 进行追踪。生成的 nn.Module 与其他并行技术的组合仍处于实验阶段,可能需要一些变通方法。如果用户难以修改模型代码,使用此前端可能更具吸引力。更多信息请参见选项2:自动拆分模型。


步骤2:使用PipelineSchedule执行

现在我们可以将PipelineStage附加到流水线调度器上,并通过输入数据运行该调度器。以下是一个GPipe示例:

from torch.distributed.pipelining import ScheduleGPipe# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically if rank == 0:schedule.step(x)
else:output = schedule.step()

请注意,上述代码需要在每个工作节点上启动,因此我们使用一个启动器服务来启动多个进程:

torchrun --nproc_per_node=2 example.py

模型分割方案


方案一:手动拆分模型

要直接构建一个PipelineStage,用户需要负责提供一个单独的nn.Module实例,该实例需包含相关的nn.Parametersnn.Buffers,并定义一个forward()方法来执行该阶段相关的操作。例如,Torchtitan中定义的Transformer类精简版展示了一种构建易于分区模型的模式。


class Transformer(nn.Module):def __init__(self, model_args: ModelArgs):super().__init__()self.tok_embeddings = nn.Embedding(...)# Using a ModuleDict lets us delete layers without affecting names,  # ensuring checkpoints will correctly save and load.self.layers = torch.nn.ModuleDict()for layer_id in range(model_args.n_layers):self.layers[str(layer_id)] = TransformerBlock(...)self.output = nn.Linear(...)def forward(self, tokens: torch.Tensor):# Handling layers being 'None' at runtime enables easy pipeline splittingh = self.tok_embeddings(tokens) if self.tok_embeddings else tokensfor layer in self.layers.values():h = layer(h, self.freqs_cis)h = self.norm(h) if self.norm else houtput = self.output(h).float() if self.output else hreturn output

以这种方式定义的模型可以轻松按阶段进行配置,具体步骤如下:

首先初始化整个模型(使用 meta-device 避免内存不足错误),然后删除该阶段不需要的层,最后创建一个封装模型的 PipelineStage。例如:

with torch.device("meta"):assert num_stages == 2, "This is a simple 2-stage example"# we construct the entire model, then delete the parts we do not need for this stage# in practice, this can be done using a helper function that automatically divides up layers across stages.model = Transformer()if stage_index == 0:# prepare the first stage modeldel model.layers["1"]model.norm = Nonemodel.output = Noneelif stage_index == 1:# prepare the second stage modelmodel.tok_embeddings = Nonedel model.layers["0"]from torch.distributed.pipelining import PipelineStagestage = PipelineStage(model,  stage_index,  num_stages,  device, )

当与其他数据或模型并行技术结合使用时,如果模型分块的输出形状/数据类型会受到影响,可能还需要指定 output_args


选项2:自动拆分模型

如果您拥有完整模型,且不想花费时间将其修改为一系列"模型分区",那么pipeline API可以为您提供帮助。以下是一个简单示例:

class Model(torch.nn.Module):def __init__(self) -None:super().__init__()self.emb = torch.nn.Embedding(10, 3)self.layers = torch.nn.ModuleList(Layer() for _ in range(2))self.lm = LMHead()def forward(self, x: torch.Tensor) -torch.Tensor:x = self.emb(x)for layer in self.layers:x = layer(x)x = self.lm(x)return x

如果打印模型,我们会看到多个层级结构,这使得手动拆分变得困难:

Model((emb): Embedding(10, 3)(layers): ModuleList((0-1): 2 x Layer((lin): Linear(in_features=3, out_features=3, bias=True)))(lm): LMHead((proj): Linear(in_features=3, out_features=3, bias=True))
)

让我们看看 pipeline API 的工作原理:

from torch.distributed.pipelining import pipeline, SplitPoint# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])pipe = pipeline(module=mod, mb_args=(x,), split_spec={"layers.1": SplitPoint.BEGINNING, }
)

pipeline API 根据给定的 split_spec 对模型进行分割,其中:

SplitPoint.BEGINNING 表示在 forward 函数中特定子模块执行之前添加分割点,类似地,SplitPoint.END 表示在此类子模块执行之后添加分割点。

如果我们执行 print(pipe),可以看到:

GraphModule((submod_0): GraphModule((emb): InterpreterModule()(layers): Module((0): InterpreterModule((lin): InterpreterModule())))(submod_1): GraphModule((layers): Module((1): InterpreterModule((lin): InterpreterModule()))(lm): InterpreterModule((proj): InterpreterModule()))
)def forward(self, x):submod_0 = self.submod_0(x);  x = Nonesubmod_1 = self.submod_1(submod_0);  submod_0 = Nonereturn (submod_1,)

“模型分区”由子模块(submod_0submod_1)表示,每个子模块都使用原始模型的操作、权重和层次结构进行重构。此外,还重构了一个“根级别”的forward函数,用于捕获这些分区之间的数据流。后续将由流水线运行时以分布式方式重放这些数据流。

Pipe对象提供了一个用于检索“模型分区”的方法:

stage_mod : nn.Module = pipe.get_stage_module(stage_idx)

返回的 stage_mod 是一个 nn.Module,你可以用它来创建优化器、保存或加载检查点,或者应用其他并行策略。

Pipe 还允许你基于给定的 ProcessGroup 在设备上创建分布式阶段运行时环境。


stage = pipe.build_stage(stage_idx, device, group)

或者,如果您希望在修改 stage_mod 后稍后再构建 stage 运行时,可以使用 build_stage API 的函数式版本。例如:

from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParalleldp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)

注意:pipeline 前端使用追踪器 (torch.export) 将你的模型捕获为单一计算图。如果你的模型无法实现全图捕获,可以使用下方提供的手动前端。


Hugging Face 示例

在最初创建此包的 PiPPy 代码库中,我们保留了基于未修改的 Hugging Face 模型的示例。请参阅 examples/huggingface 目录。


示例包括:

  • GPT2
  • Llama

技术深度解析


pipeline API 如何分割模型?

首先,pipeline API 通过追踪模型将其转换为有向无环图(DAG)。它使用 PyTorch 2 的全图捕获工具 torch.export 来追踪模型。

然后,它将一个阶段所需的操作和参数分组到重建的子模块中:submod_0submod_1 等。

与传统的子模块访问方法(如 Module.children())不同,pipeline API 不仅切割模型的模块结构,还会切割模型的 forward 函数。

这是必要的,因为像 Module.children() 这样的模型结构仅捕获 Module.__init__() 期间的信息,而不会捕获任何关于 Module.forward() 的信息。换句话说,Module.children() 缺少以下对流水线至关重要的信息:

  • forward 中子模块的执行顺序
  • 子模块之间的激活流
  • 子模块之间是否存在任何函数式操作(例如,reluadd 操作不会被 Module.children() 捕获)。

相反,pipeline API 确保 forward 行为被完整保留。它还捕获分区之间的激活流,帮助分布式运行时无需人工干预即可正确执行发送/接收调用。

pipeline API 的另一个灵活性是,分割点可以位于模型层次结构的任意级别。在分割后的分区中,与该分区相关的原始模型层次结构会被重建,且不会带来额外开销。因此,指向子模块或参数的完全限定名称(FQN)仍然有效,依赖 FQN 的服务(如 FSDP、TP 或检查点)几乎无需代码更改即可继续运行。


实现自定义调度策略

您可以通过扩展以下两个基类之一来实现自己的流水线调度策略:

  • PipelineScheduleSingle
  • PipelineScheduleMulti

PipelineScheduleSingle 适用于每个计算节点仅分配单个阶段的调度策略。
PipelineScheduleMulti 则适用于每个计算节点分配多个阶段的调度策略。

例如:

  • ScheduleGPipeSchedule1F1BPipelineScheduleSingle 的子类
  • ScheduleInterleaved1F1BScheduleLoopedBFSScheduleInterleavedZeroBubble 以及 ScheduleZBVZeroBubble 则是 PipelineScheduleMulti 的子类

日志记录

您可以通过设置torch._logging中的TORCH_LOGS环境变量来启用额外的日志记录功能:

  • TORCH_LOGS=+pp 会显示logging.DEBUG及以上级别的日志信息
  • TORCH_LOGS=pp 会显示logging.INFO及以上级别的日志信息
  • TORCH_LOGS=-pp 会显示logging.WARNING及以上级别的日志信息

API 参考


模型拆分 API

以下一组 API 可将您的模型转换为流水线表示形式。


class torch.distributed.pipelining.SplitPoint(value)

表示在子模块执行过程中可插入切分点的枚举类型。

:ivar BEGINNING: 表示在前向函数中某个子模块执行之前添加切分点。

:ivar END: 表示在前向函数中某个子模块执行之后添加切分点。


torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)

根据规范拆分模块。

更多详情请参阅 Pipe。

参数

  • module ( Module ) – 待拆分的模块。
  • mb_args ( tuple [Any , ...]) – 示例位置输入,以微批次形式提供。
  • mb_kwargs (Optional[dict[str, Any ]]) – 示例关键字输入,以微批次形式提供。(默认值:None)
  • split_spec (Optional[dict[str, torch.distributed.pipelining._IR.SplitPoint]]) – 使用子模块名称作为拆分标记的字典。(默认值:None)
  • split_policy (Optional[Callable [[GraphModule)],GraphModule]]) – 用于拆分模块的策略。(默认值:None)

返回类型:返回 Pipe 类的流水线表示形式。


class torch.distributed.pipelining.Pipe(split_gm, num_stages, has_loss_and_backward, loss_spec)

torch.distributed.pipelining.pipe_split()

pipe_split 是一个特殊运算符,用于标记模块中各阶段之间的边界。它的作用是将模块拆分为多个阶段。如果你以即时执行模式运行带注解的模块,该运算符不会产生任何效果。


示例:

>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x

上述示例将被拆分为两个阶段。


微批次工具集


class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)

用于指定输入分块的类


torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)

根据给定的参数序列(args和kwargs),按照各自的分块规格将它们分割成多个块。

参数说明:

  • args (tuple[Any, ...]) - 参数元组
  • kwargs (Optional[dict[str, Any]]) - 关键字参数字典
  • chunks (int) - 要将args和kwargs分割成的块数
  • args_chunk_spec (Optional[tuple[torch.distributed.pipelining.microbatch.TensorChunkSpec, ...]]) - args的分块规格,形状与args相同
  • kwargs_chunk_spec (Optional[dict[str, torch.distributed.pipelining.microbatch.TensorChunkSpec]]) - kwargs的分块规格,形状与kwargs相同

返回值说明:

  • args_split: 分割后的参数列表
  • kwargs_split: 分割后的关键字参数字典列表
  • 返回类型: args_split

torch.distributed.pipelining.microbatch.merge_chunks(chunks, chunk_spec)

根据分块规范将给定的分块列表合并为单个值。

参数

  • chunks (list[Any ]) - 分块列表
  • chunk_spec - 分块的分块规范

返回值:合并后的值

返回类型:值


流水线阶段


class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)

一个表示流水线并行设置中流水线阶段的类。

PipelineStage 假设模型采用顺序分区方式,即模型被分割成多个块,其中一个块的输出作为下一个块的输入,不存在跳跃连接。

PipelineStage 通过按线性顺序将 stage0 的输出传播到 stage1 等方式,自动执行运行时形状/数据类型推断。若要绕过形状推断,需将 input_args 和 output_args 传递给每个 PipelineStage 实例。

参数

  • submodule (nn.Module) – 该阶段封装的 PyTorch 模块。
  • stage_index ( int ) – 本阶段的 ID。
  • num_stages ( int ) – 阶段总数。
  • device ( torch.device ) – 本阶段所在的设备。
  • input_args (Union[torch.Tensor, Tuple[torch.tensor]], 可选) – 子模块的输入参数。
  • output_args (Union[torch.Tensor, Tuple[torch.tensor]], 可选) – 子模块的输出参数。
  • group (dist.ProcessGroup, 可选) – 分布式训练的进程组。若为 None,则使用默认组。
  • dw_builder (Optional[Callable[[], Callable[...*, None]]) – 若提供,dw_builder 将构建一个新的 dw_runner 函数,该函数会为 F、I、W(前向、输入、权重)零气泡调度执行 W 动作(输入权重)。

torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)

创建一个流水线阶段,给定需要被该阶段包装的stage_module以及流水线信息。


参数

  • stage_module ( torch.nn.Module ) – 需要被该阶段包装的模块
  • stage_index ( int ) – 该阶段在流水线中的索引
  • pipe_info (PipeInfo) – 关于流水线的信息,可通过pipe.info()获取
  • device ( torch.device ) – 该阶段使用的设备
  • group (Optional[dist.ProcessGroup]) – 该阶段使用的进程组

返回一个可与PipelineSchedules一起运行的流水线阶段。

返回类型:_PipelineStage


流水线调度


class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

GPipe调度方案。

采用填充-排空的方式处理所有微批次数据。


class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

1F1B调度方案。

在稳定状态下,将对微批次执行一次前向和一次后向操作。


class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

交错式1F1B调度方案。

详情请参阅https://arxiv.org/pdf/2104.04473。

在稳定状态下,该方案会对微批次执行一次前向和一次后向计算,并支持每个rank处理多个阶段。当微批次准备好进行多个本地阶段计算时,交错式1F1B会优先处理较早的微批次(也称为"深度优先")。

该调度方案与原始论文基本相似,主要区别在于放宽了num_microbatch % pp_size == 0的要求。使用flex_pp调度时,我们会得到num_rounds = max(1, n_microbatches // pp_group_size),只要满足n_microbatches % num_rounds == 0即可正常工作。例如:

1、pp_group_size = 4,n_microbatches = 10时,num_rounds = 2且n_microbatches % 2 == 0
2、pp_group_size = 4,n_microbatches = 3时,num_rounds = 1且n_microbatches % 1 == 0


class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None, scale_grads=True)

广度优先流水线并行。

详情请参阅https://arxiv.org/abs/2211.05953。

与交错式1F1B类似,循环BFS支持每个rank运行多个阶段。

不同之处在于,当多个本地阶段的微批次准备就绪时,循环BFS会优先处理较早的阶段,一次性运行所有可用的微批次。


class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

零气泡计划(ZBV变体)。

详情请参见 https://arxiv.org/pdf/2401.10241 第6节。

此计划要求每个等级恰好有两个阶段。

该计划将在稳定状态下对微批次的输入执行一次前向传播和一次后向传播,并支持每个等级有多个阶段。使用相对于权重的后向传播来填补管道气泡。

只有当时间前向传播 == 时间后向传播输入 == 时间后向传播权重时,这个ZB-V计划才具有“零气泡”属性。

实际上,对于真实模型来说,这不太可能是真的,所以可以选择实现一个贪婪调度器来处理不平等或不平衡的时间。


class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

零气泡调度方案(ZBV变体)。

详情请参阅https://arxiv.org/pdf/2401.10241第6节。

该调度方案要求每个rank(计算节点)严格使用两个阶段。

在稳定状态下,该方案会对微批次的输入执行一次前向传播和一次反向传播,并支持每个rank多阶段处理。通过权重反向传播来填补流水线气泡。

只有当满足以下条件时,该ZB-V调度方案才具备"零气泡"特性:前向传播时间 == 输入反向传播时间 == 权重反向传播时间。

实际应用中,真实模型很难满足这一条件。因此,针对时间不均衡的情况,可以改用贪心调度器实现。


class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

单阶段计划的基础类。

实现了步骤方法。

派生类应该实现 _step_microbatches 方法。

根据 scale_grads 参数,梯度会根据 num_microbatches 进行缩放,默认为 True。这个设置应该与您的 loss_fn 的配置相匹配,loss_fn 可能是平均损失(scale_grads=True)或总和损失(scale_grads=False)。


step(*args, target=None, losses=None, **kwargs)

运行一次管道计划的迭代,使用 whole-batch 输入。

将自动将输入分块为微批次,并根据计划实现依次处理微批次。

args: 模型的位置参数(与非管道情况相同)。

kwargs: 模型的关键字参数(与非管道情况相同)。

target: 损失函数的目标。

losses: 用于存储每个微批次的损失的列表。


class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, use_full_backward=None, scale_grads=True)

多阶段计划的基础类。

实现了步骤方法。

根据 scale_grads 参数,梯度会根据 num_microbatches 进行缩放,默认为 True。这个设置应该与您的 loss_fn 的配置相匹配,loss_fn 可能是平均损失(scale_grads=True)或总和损失(scale_grads=False)。


step(*args, target=None, losses=None, **kwargs)

运行管道调度的一次迭代,使用全批次输入。

该方法会自动将输入切分为微批次,并根据调度实现依次处理这些微批次。

参数说明:

  • args: 传递给模型的位置参数(与非管道式情况相同)
  • kwargs: 传递给模型的关键字参数(与非管道式情况相同)
  • target: 损失函数的目标值
  • losses: 用于存储每个微批次损失值的列表


分布式检查点 - torch.distributed.checkpoint

分布式检查点(DCP)支持并行地从多个计算节点加载和保存模型。它能够处理加载时的重分片操作,这使得在一个集群拓扑中保存的模型可以加载到另一个不同拓扑的集群中。

DCP与torch.save和torch.load在几个重要方面存在差异:

  • 每个检查点会生成多个文件,每个计算节点至少对应一个文件
  • 它以原地方式操作,这意味着模型需要先分配其数据存储空间,DCP会直接使用这些预分配的存储空间

以下是加载和保存检查点的主要入口方法:

附加资源:

  • 分布式检查点(DCP)入门指南
  • 使用分布式检查点(DCP)进行异步保存
  • TorchTitan 检查点功能文档
  • TorchTitan DCP 实现代码

class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value)

异步检查点类型的枚举。


torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False)

以SPMD风格保存分布式模型。

此函数与torch.save()不同,它通过让每个rank仅保存本地分片来处理ShardedTensorDTensor

对于每个Stateful对象(同时具有state_dictload_state_dict方法),保存操作会在序列化前调用state_dict


警告:不同PyTorch版本间保存的state_dict不保证具有向后兼容性。


警告:如果使用process_group参数,请确保只有该组的rank调用save_state_dict,且state_dict中的所有数据都属于该组。


注意:当为FSDP的ShardingStrategy.HYBRID_SHARD保存检查点时,只有一个shard_group应调用save_state_dict,且需要传入对应的process_group。


注意:

如果没有可用的进程组,此函数会假定意图是在本地进程中保存state_dict。


参数

  • state_dict (Dict[str, Any]) – 要保存的state_dict。
  • checkpoint_id (Union[str, os.PathLike, None]) – 检查点实例的ID。checkpoint_id的具体含义取决于存储类型,可以是文件夹路径、文件路径,或者键值存储中的键名(默认:None)。
  • storage_writer (Optional[StorageWriter]) – 用于执行写入操作的StorageWriter实例。如果未指定,DCP会根据checkpoint_id自动推断写入器。如果checkpoint_id也为None,将抛出异常(默认:None)。
  • planner (Optional[SavePlanner]) – SavePlanner实例。如果未指定,将使用默认planner(默认:None)。
  • process_group (Optional[ProcessGroup]) – 用于跨rank同步的进程组(默认:None)。
  • no_dist ([bool]) – 如果为True,此函数将假定意图是在不使用跨rank同步的情况下加载检查点(默认:False)。

返回

已保存检查点的元数据对象。

返回类型:Metadata


示例


>>> my_model = MyModule()

>>> state_dict = {"model": my_model}

>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict, >>    storage_writer=fs_storage_writer, >>)

注意save_state_dict 使用集合通信(collectives)来协调不同进程间的写入操作。

对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生前移至 GPU 设备。

此时,所使用的设备由 torch.cuda.current_device() 指定,用户需自行确保通过 torch.cuda.set_device() 正确设置,使每个进程独占一个 GPU。


torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD)

save 的异步版本。该代码首先将 state_dict 从暂存区移出到暂存存储(默认为 CPU 内存),然后在单独的线程中调用保存操作。

警告:此功能为实验性质,可能会发生变化。

参数

  • state_dict (Dict[str, Any]) – 要保存的 state_dict。
  • checkpoint_id (Union[str,* os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的具体含义取决于存储类型。它可以是文件夹路径或文件路径。如果存储是键值存储,它也可以是键。(默认值:None
  • storage_writer (Optional[StorageWriter)]) – 用于执行 ‘stage’ 和 ‘save’ 的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会抛出异常。(默认值:None
  • planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定,将使用默认的 planner。(默认值:None
  • process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认值:None

返回值:一个持有保存操作返回的 Metadata 对象的 future。

返回类型:Future

示例


>>> my_model = MyModule()

>>> state_dict = {"model": my_model}

>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict, >>    storage_writer=fs_storage_writer, >>)
>>> >
>>> # ... do some work ...
>>> >
>>> checkpoint_future.result()

torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)

此方法已弃用。请改用 save

返回类型:Metadata


torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)

以SPMD风格将检查点加载到分布式状态字典中。

每个进程提供的state_dict必须包含相同的键。键不匹配可能导致挂起或错误。如果不确定,可以使用utils._assert_same_keys API进行检查(但可能会产生通信开销)。

每个进程会尝试读取最少量的数据来填充请求的state_dict。当加载ShardedTensorDTensor实例时,每个进程仅读取其本地分片的数据。

对于每个Stateful对象(同时具有state_dictload_state_dict方法),加载操作会先调用state_dict,然后尝试反序列化,反序列化完成后调用load_state_dict

对于非Stateful对象,加载操作会反序列化对象,然后在state_dict中用反序列化后的对象替换原对象。


警告:state_dict中的所有张量必须在调用此函数之前分配到目标设备上。

所有非张量数据使用torch.load()加载,并在state_dict中就地修改。


警告:用户必须在根模块上调用load_state_dict,以确保加载后处理和非张量数据正确传播。


参数

  • state_dict (Dict[str, Any]) – 要加载检查点的状态字典。
  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储类型。可以是文件夹路径、文件路径,如果存储是键值存储也可以是键名。(默认: None)
  • storage_reader (Optional[[StorageReader](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.StorageReader "torch.distributed.checkpoint.StorageReader")]) – 用于执行读取操作的StorageWriter实例。如果未指定,DCP会根据checkpoint_id自动推断读取器。如果checkpoint_id也为None,则会抛出异常。(默认: None)
  • planner (Optional[LoadPlanner]) – LoadPlanner实例。如果未指定,将使用默认规划器。(默认: None)
  • process_group (Optional[ProcessGroup]) – 用于跨进程同步的ProcessGroup。(默认: None)
  • no_dist ([bool]) – 如果为True,此函数将假定目的是在不使用跨进程同步的情况下加载检查点。(默认: False)

返回

无。

返回类型:无


示例


>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
...     "/checkpoint/1"
... )

>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict, >>    storage_reader=fs_storage_reader, >>)

>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to >># ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意load_state_dict 使用集合通信来协调跨进程的读取操作。

对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生前移至 GPU 设备。

此时使用的设备由 torch.cuda.current_device() 指定,用户需自行确保通过 torch.cuda.set_device() 正确设置,使每个进程独占一个 GPU。


torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)

该方法已弃用,请改用 load

以下模块还可用于对异步检查点(torch.distributed.checkpoint.async_save)使用的暂存机制进行额外定制:

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)

该协议旨在为dcp.async_save提供定制化和扩展能力,允许用户在并行执行常规dcp.save流程前自定义数据暂存方式。

预期操作顺序(具体定义于torch.distributed.state_dict_saver.async_save)如下:

1、AsyncStager.stage_data(state_dict):此调用为AsyncStager提供"暂存"state_dict的机会。此处的暂存预期目的是创建state_dict的"训练安全"表示形式,这意味着暂存完成后对模块数据的任何更新都不应反映在该方法返回的state_dict中。例如默认情况下,会在CPU内存中创建整个state_dict的副本并返回,从而允许用户继续训练而不影响正在被序列化的数据。

2、对暂存返回的state_dict并行调用dcp.save。该调用负责序列化state_dict并将其写入存储。

3、若AsyncStager.should_synchronize_after_execute为True,该方法将在序列化线程启动后、从dcp.async_save返回前立即调用。若设为False,则假定用户已定义自定义同步点以进一步优化训练循环中的保存延迟(例如通过将暂存与前向/反向传播重叠),此时用户需在适当时机调用AsyncStager.synchronize_staging


property should_synchronize_after_execute:  bool 

是否在执行阶段后进行同步。


stage(state_dict)

返回一个"暂存"状态的 state_dict 副本。该暂存副本的特性是:在 stage 调用完成后,不会受到任何后续更新的影响。

返回类型:dict[str , Union [~StatefulT, Any ]


synchronize_staging()

如果阶段以某种方式异步进行,应调用此方法以确保暂存完成,此时可以安全地开始修改原始 state_dict。


class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)

一个实现了 AsyncStager 的类,将 state_dict 暂存到 CPU 内存中,并阻塞直到复制完成。

该实现还提供了使用固定内存来优化暂存延迟的选项。

注意:在这种情况下,synchronize_staging 是一个空操作。


stage(state_dict)

返回一个位于CPU上的state_dict副本。

返回类型:dict[str, Union[~StatefulT, Any]]


synchronize_staging()

无操作函数,因为暂存是阻塞式的。

除了上述入口点外,如下所述的有状态对象在保存/加载过程中提供了额外的自定义功能

… automodule:: torch.distributed.checkpoint.stateful


class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)

支持检查点(checkpoint)与恢复功能的对象状态协议。


load_state_dict(state_dict)

从提供的 state_dict 恢复对象的状态。

参数

  • state_dict ( dict[str, Any ]) – 用于恢复的状态字典

state_dict()


Objects should return their state_dict representation as a dictionary.
The output of this function will be checkpointed, and later restored in load_state_dict().


Warning: Because of the inplace nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

Returns
The objects state dict

Return type
Dict

This example shows how to use Pytorch Distributed Checkpoint to save a FSDP model.

The following types define the IO interface used during checkpoint:


class torch.distributed.checkpoint.StorageReader

Interface used by load_state_dict to read from storage.

One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expected the following sequence of calls by load_state_dict:

0、(all ranks) set checkpoint_id if users pass a valid checkpoint_id.
1、(all ranks) read_metadata()
2、(all ranks) set_up_storage_reader()
3、(all ranks) prepare_local_plan()
4、(coordinator) prepare_global_plan()
5、(all ranks) read_data()


ABSTRACT  prepare_global_plan(plans)

Perform centralized planning of storage loading.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred
way is to store storage specific data in LoadPlan::storage_data.


Parameters

  • plans (list[torch.distributed.checkpoint.planner.LoadPlan]) – A list of LoadPlan instances, one for each rank.

Returns
A list of transformed LoadPlan after storage global planning

Return type
list [torch.distributed.checkpoint.planner.LoadPlan]


ABSTRACT prepare_local_plan(plan)

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended
way is to store storage specific data in LoadPlan::storage_data.


Parameters

  • plan (LoadPlan) – The local plan from the LoadPlan in use.

Returns
A transformed LoadPlan after storage local planning

Return type
LoadPlan


ABSTRACT read_data(plan, planner)

Read all items from plan using planner to resolve the data.

A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO
object into the right place.

A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that in should load data into.

It’s the StorageLayer responsibility to properly schedule any cross device copies
required.


Parameters

  • plan (LoadPlan) – The local plan to execute on * planner (LoadPlanner) – The planner object to use to resolve items.

Returns
A future that completes once all reads are finished.

Return type
Future [None]


read_metadata()

摘要


Read the checkpoint metadata.

Returns
The metadata object associated with the checkpoint being loaded.

Return type
Metadata


ABSTRACT  reset(checkpoint_id=None)

Calls to indicates a brand new checkpoint read is going to happen.
A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.


Parameters

  • checkpoint_id (Union[str,* os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id
    depends on the storage. It can be a path to a folder or to a file.
    It can also be a key if the storage is more like a key-value store.
    (Default: None)


ABSTRACT set_up_storage_reader(metadata, is_coordinator) 

Initialize this instance.


Parameters

  • metadata (Metadata) – The metadata schema to use.
  • is_coordinator ([bool]) – Whether this instance is responsible for coordinating the checkpoint.

Abstract Classmethod validate_checkpoint_id(checkpoint_id)

检查给定的 checkpoint_id 是否被存储支持。这允许我们启用自动存储选择。

返回类型:bool


class torch.distributed.checkpoint.StorageWriter

save_state_dict 用于写入存储的接口。

在分布式检查点中,一个 StorageWriter 实例同时充当协调者和跟随者角色。初始化时,每个实例都会被告知其角色。

子类应遵循以下调用顺序:

0、(所有进程)如果用户提供了有效的 checkpoint_id,则设置 checkpoint_id

1、(所有进程)调用 set_up_storage_writer()

2、(所有进程)调用 prepare_local_plan()

3、(协调者)调用 prepare_global_plan()

4、(所有进程)调用 write_data()

5、(协调者)调用 finish()


ABSTRACT  finish(metadata, results)

写入元数据并将当前检查点标记为成功。

用于序列化元数据的实际格式/模式是实现细节,唯一要求是能够还原为相同的对象图。

参数

  • metadata (Metadata) – 新检查点的元数据
  • results (list[list[torch.distributed.checkpoint.storage.WriteResult]]) – 来自所有进程的WriteResults列表

返回值:无

返回类型:无


ABSTRACT  prepare_global_plan(plans)

执行存储的集中规划。

此方法仅在协调器实例上调用。

虽然该方法可以生成完全不同的规划方案,但推荐的方式是将存储特定数据保存在 SavePlan::storage_data 中。

参数

  • plans (list[[torch.distributed.checkpoint.planner.SavePlan](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlan "torch.distributed.checkpoint.planner.SavePlan")]) – 一个包含各rank对应SavePlan实例的列表。

返回
经过存储全局规划处理后的SavePlan列表

返回类型
list[torch.distributed.checkpoint.planner.SavePlan]


ABSTRACT  prepare_local_plan(plan)

执行存储特定的本地规划。

虽然此方法可以生成完全不同的计划,但推荐的方式是将存储特定数据保存在 SavePlan::storage_data 中。

参数

  • plan ([SavePlan](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlan "torch.distributed.checkpoint.SavePlan")) – 当前使用的 SavePlanner 生成的本地计划。

返回

经过存储本地规划转换后的 SavePlan

返回类型

SavePlan


ABSTRACT  reset(checkpoint_id=None)

调用表示即将开始一次全新的检查点写入。

如果用户为本次检查点写入设置了checkpoint_id,则该参数可能存在。checkpoint_id的具体含义取决于存储实现,可能是指向文件夹/文件的路径,也可能是键值存储中的键名。

参数说明

  • checkpoint_id (Union[str, os.PathLike, None]) - 本次检查点实例的ID。checkpoint_id的具体含义取决于存储类型:
    • 对于文件系统存储,可以是文件夹路径或文件路径
    • 对于键值存储,可以是键名
      (默认值:None

ABSTRACT  set_up_storage_writer(is_coordinator)

初始化该实例。

参数

  • is_coordinator ([bool]) – 该实例是否负责协调检查点。

storage_meta()

返回存储特定的元数据。该方法用于在检查点中存储额外信息,这些信息有助于提供请求级别的可观测性。在保存调用期间,StorageMeta会被传递给SavePlanner。默认返回None。

TODO: 提供一个示例

返回类型:Optional[StorageMeta]


ABSTRACT classmethod* validate_checkpoint_id(checkpoint_id)

检查给定的 checkpoint_id 是否被存储系统支持。这让我们能够启用自动存储选择功能。

返回类型:bool


ABSTRACT  write_data(plan, planner)

使用 planner 解析数据,将 plan 中的所有条目写入。

子类应对计划中的每个条目调用 SavePlanner::resolve_data 方法,以获取待写入的底层对象访问权限。子类应惰性调用 resolve_data,因为该方法可能涉及内存分配。

对于张量,需遵循以下假设:

  • 张量可能位于任意设备上(包括与 WriteItem::tensor_data 设备不匹配的情况)
  • 张量可能是视图或非连续的,仅需保存其投影部分

参数

  • plan ([SavePlan](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlan "torch.distributed.checkpoint.SavePlan")) – 要执行的保存计划
  • planner ([SavePlanner](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlanner "torch.distributed.checkpoint.SavePlanner")) – 用于将条目解析为数据的规划器对象

返回值
一个最终返回 WriteResult 列表的 Future 对象

返回类型
Future [list [torch.distributed.checkpoint.storage.WriteResult]]

以下类型定义了检查点期间使用的规划器接口:

class torch.distributed.checkpoint.LoadPlanner

定义加载状态字典(load_state_dict)所用协议的抽象基类。

LoadPlanner是有状态对象,可用于自定义整个加载流程。它作为访问state_dict的代理,任何对字典的修改都会在整个流程中可见。

load_state_dict执行期间,规划器子类会按以下顺序接收调用:

1、set_up_planner - 所有rank节点都会调用。标志检查点加载开始
2、create_local_plan - 所有rank节点调用。处理state_dict并生成将用于全局规划的LoadPlan
3、create_global_plan - 仅协调者rank节点调用。汇总各rank的LoadPlan并做出全局决策
4、load_bytes - 每个rank节点会多次调用。对应state_dict中每个非张量值调用一次
5、resolve_tensorcommit_tensor - 每个rank节点成对调用。对应state_dict中每个张量值调用

建议用户继承DefaultLoadPlanner而非直接实现此接口,因为多数修改只需重写单个方法即可实现。

扩展规划器通常有两种模式:

重写state_dict。这是扩展加载流程最简单的方式,因为不需要理解LoadPlan的内部机制。由于加载是原地(in-place)操作,我们需要保留原始state_dict的引用,以便执行原地修改。


>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self, >>        state_dict: STATE_DICT_TYPE, >>        metadata: Metadata, >>        is_coordinator: bool, >>    ) -None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>> >
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>> >
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>> >
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>> >
>>>     def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

修改 resolve_tensorcommit_tensor 方法以支持加载时转换。


>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>> >
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor

ABSTRACT  commit_tensor(read_item, tensor)

StorageReader完成将数据加载到tensor后调用一次。

提供的tensor与调用resolve_tensor返回的是同一个。

仅当该LoadPlanner需要在将tensor复制回state_dict之前进行后处理时,才需要此方法。

tensor的内容将遵循其设备同步模型。


ABSTRACT  create_global_plan(global_plan)

计算全局加载计划并返回每个rank的加载计划。

注意:此方法仅在协调器rank上调用。

返回类型:list [torch.distributed.checkpoint.planner.LoadPlan]


ABSTRACT  create_local_plan()

基于set_up_planner提供的state_dict和元数据创建加载计划。

注意:此方法会在每个rank上调用。

返回类型:LoadPlan


ABSTRACT  finish_plan(central_plan)

接受协调器的计划并返回最终的加载方案。

返回类型:LoadPlan


ABSTRACT  load_bytes(read_item, value)

加载由 read_itemvalue 描述的项。

该方法预期会就地修改底层的 state_dict。

value 的内容由用于生成待加载检查点的 SavePlanner 定义。


resolve_bytes(read_item)

返回供 StorageReader 用于加载 read_item 的 BytesIO 对象。

该 BytesIO 应与底层 state_dict 中的对象建立别名关系,因为 StorageReader 会替换其内容。

返回类型:BytesIO


ABSTRACT  resolve_tensor(read_item)

返回由 read_item 描述的张量,供 StorageReader 用于加载 read_item。

该张量应与底层 state_dict 中的某个张量建立别名关系,因为 StorageReader 会替换其内容。

如果因任何原因无法实现这一点,规划器可以使用 commit_tensor 方法将数据复制回 state_dict 中的对应张量。

返回类型:Tensor


ABSTRACT  set_up_planner(state_dict, metadata=None, is_coordinator=False)

初始化该实例以将数据加载到 state_dict 中。

注意:此操作会在每个 rank 上调用。


class torch.distributed.checkpoint.LoadPlan(items:  list [[torch.distributed.checkpoint.planner.ReadItem](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.ReadItem "torch.distributed.checkpoint.planner.ReadItem")], storage_data: Any = None, planner_data: Any = None)

class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets:  torch.Size , storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets:  torch.Size , lengths:  torch.Size )

class torch.distributed.checkpoint.SavePlanner

定义保存状态字典(save_state_dict)所用协议的抽象类。

SavePlanner 是有状态对象,可用于自定义整个保存过程。它作为访问 state_dict 的代理,因此对其进行的任何转换都会对整个过程可见。

在 save_state_dict 过程中,规划器子类会按以下顺序调用方法:

1、set_up_planner - 在所有 rank 上调用。标志检查点保存开始
2、create_local_plan - 在所有 rank 上调用。处理 state_dict 并生成将用于全局规划的 SavePlan
3、create_global_plan - 仅在协调器 rank 上调用。汇总各 rank 的 SavePlan 并做出全局决策
4、finish_plan - 在所有 rank 上调用。使各 rank 能根据全局规划决策进行调整
5、resolve_data - 在每个 rank 上多次调用。为存储层查找 state_dict 中的值以供写入

建议用户直接继承 DefaultSavePlanner 而非本接口,因为大多数修改只需更改单个方法即可实现。

扩展通常有三种模式:

重写 state_dict。这是扩展保存过程最简单的方式,因为它不需要理解 SavePlan 的内部工作机制。


>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self, >>        state_dict: STATE_DICT_TYPE, >>        storage_meta: Optional[StorageMeta], >>        is_coordinator: bool, >>    ) -None:
>>> # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

同步修改本地计划和查询。这在需要精细控制数据持久化方式时非常有用。


>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>> >
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤来制定无法由每个节点单独做出的中心化决策


>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(zip_longest(iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,某些规划器需要在检查点中保存额外的元数据。实现方式是让每个节点在本地计划中贡献其数据项,然后由全局规划器进行聚合:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>> >
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata

ABSTRACT  create_global_plan(all_plans)

计算全局检查点计划并返回每个rank的本地计划。

此方法仅在协调器rank上调用。

返回类型:tuple [list [torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata]


ABSTRACT  create_local_plan()

计算当前秩的保存计划。

该计划将被聚合并传递给create_global_plan

可以通过SavePlan::planner_data传递规划器特定数据。

此操作在所有秩上调用。

返回类型:SavePlan


ABSTRACT  finish_plan(new_plan)

create_local_plan 创建的规划与 create_global_plan 的结果进行合并。

此方法在所有进程上调用。

返回类型:SavePlan


ABSTRACT  resolve_data(write_item)

转换并准备来自 state_dictwrite_item 以进行存储,确保操作的幂等性和线程安全性。

在存储层处理之前,从 state_dict 中查找与 write_item 关联的对象,并应用任何转换(例如序列化)。

该方法会在每个 rank 上被多次调用,最终 SavePlan 中的每个 WriteItem 至少调用一次。

此方法应具备幂等性和线程安全性。StorageWriter 实现可以按需自由调用它。

为了减少检查点操作所需的内存峰值,任何涉及内存分配的转换都应延迟到调用该方法时执行。

返回张量时,它们可以位于任何设备或格式上,也可以是视图。存储层需自行确定如何保存它们。

返回类型:
Union [Tensor, BytesIO]


ABSTRACT  set_up_planner(state_dict, storage_meta=None, is_coordinator=False)

初始化此规划器以保存 state_dict

实现时应保存这些值,因为在后续保存过程中不会再次提供这些数据。

该操作会在所有节点上调用。


class torch.distributed.checkpoint.SavePlan(items:  list [[torch.distributed.checkpoint.planner.WriteItem](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.planner.WriteItem "torch.distributed.checkpoint.planner.WriteItem")], storage_data: Any = None, planner_data: Any = None, usable:  bool  = True)

class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)

这是一个数据类,用于保存需要写入存储的信息。


tensor_storage_size()

计算底层张量的存储大小,如果不是张量写入则返回 None。

返回值:Optional[int] 底层张量的存储大小(以字节为单位),如果存在的话。

返回类型:Optional[int]

我们提供了一个基于文件系统的存储层:

class torch.distributed.checkpoint.FileSystemReader(path, _extension_registry=None)

property checkpoint_id:  Union [str , PathLike] 

返回将用于加载检查点的 checkpoint_id。


class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None)

使用文件IO实现StorageWriter的基础版本。

该实现基于以下假设和简化条件:

  • 检查点路径是一个空目录或不存在的目录
  • 文件创建操作是原子性的

每个检查点包含:每个写入请求对应一个文件,外加一个存储序列化元数据的.metadata文件。


stage(state_dict)

重写 AsyncStager.stage 方法

返回值类型:dict[str, Union[~StatefulT, Any]]

我们提供了 LoadPlanner 和 SavePlanner 的默认实现,能够处理所有 torch.distributed 结构,包括 FSDP、DDP、ShardedTensor 和 DistributedTensor。


class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)

lookup_object(index)

从规划器接口扩展,便于扩展默认规划器。

返回类型:任意


transform_object(write_item, object)

从规划器接口扩展而来,便于扩展默认规划器。


class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)

LoadPlanner基础上添加多项功能的DefaultLoadPlanner

具体新增以下特性:

  • flatten_state_dict:支持处理包含嵌套字典的state_dict
  • flatten_sharded_tensors:针对2D并行模式下的FSDP优化
  • allow_partial_load:若设为False,当state_dict中的键存在于检查点时会抛出运行时错误

lookup_tensor(index)

从规划器接口扩展而来,便于扩展默认规划器。

返回类型:Tensor


transform_tensor(read_item, tensor)

从规划器接口扩展而来,便于扩展默认规划器。

由于历史设计决策,FSDP和DDP的状态字典可能具有不同的键或完全限定名称(例如layer1.weight),即使原始未并行化的模型完全相同。此外,FSDP提供多种类型的模型状态字典,例如完整和分片状态字典。另外,优化器状态字典使用参数ID而非完全限定名称来标识参数,这在使用并行技术(如流水线并行)时可能导致问题。

为解决这些挑战,我们提供了一组API,方便用户管理状态字典。get_model_state_dict()返回的模型状态字典,其键与未并行化模型状态字典返回的键保持一致。类似地,get_optimizer_state_dict()提供的优化器状态字典,其键在所有应用的并行技术中保持统一。为实现这种一致性,get_optimizer_state_dict()将参数ID转换为与未并行化模型状态字典中完全相同的完全限定名称。

请注意,这些API返回的结果可直接与torch.distributed.checkpoint.save()torch.distributed.checkpoint.load()方法配合使用,无需任何额外转换。

set_model_state_dict()set_optimizer_state_dict()用于加载由各自getter API生成的模型和优化器状态字典。

请注意,set_optimizer_state_dict()只能在优化器调用backward()之前或step()之后调用。

请注意,此功能为实验性质,未来API签名可能会发生变化。


torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)

返回模型的状态字典(state_dict)和优化器的状态字典。

get_state_dict 能够处理任何通过 PyTorch 并行化的模块,包括 FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及这些并行方式的任意组合。get_state_dict 的主要功能包括:

1、返回一个模型和优化器的状态字典,该字典可以在不同数量的训练器和/或不同并行方式下重新分片。
2、隐藏并行化特定的状态字典 API。用户无需调用这些 API。
3、对结果状态字典进行完整性检查。

结果状态字典的键是规范的完全限定名称(FQN)。规范的 FQN 指的是基于参数在 nn.Module 层次结构中的位置生成的 FQN。更具体地说,参数的规范 FQN 是当模块未被任何并行化方式分发时,通过 module.named_parameters()module.named_buffers() 返回的 FQN。由于优化器内部使用参数 ID 来表示参数,调用此 API 时会将参数 ID 转换为规范的 FQN。

get_state_dict 也可以处理未并行化的模块。在这种情况下,get_state_dict 仅执行一项功能——将优化器的参数 ID 转换为规范的 FQN。

示例


>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict

>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)

>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
...     fsdp_model, fsdp_optim
... )

>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >># the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict

参数

  • model (nn.Module) - 需要获取状态字典的神经网络模型。
  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) - 用于优化model的优化器集合。
  • submodules (已弃用) - Optional[set[nn.Module]]: 仅返回属于指定子模块的模型参数。
  • options (StateDictOptions) - 控制如何返回模型状态字典和优化器状态字典的配置选项。详情参见StateDictOptions。

返回值:包含模型状态字典和优化器状态字典的Tuple元组。

返回类型:Tuple[Dict[str, ValueType], OptimizerStateType]


torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)

返回模型的model状态字典。

详细用法请参阅get_state_dict

参数

  • model (nn.Module) – 需要获取状态字典的nn.Module模型。
  • submodules (已弃用) – Optional[set[nn.Module]]: 仅返回属于指定子模块的模型参数。
  • options (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情参见StateDictOptions。

返回值:model的状态字典。

返回类型:Dict[str, ValueType]


torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)

返回优化器的组合状态字典。

有关详细用法,请参阅 get_state_dict

参数

  • model (nn.Module) – 用于模型的 nn.Module。
  • optimizers (Union[None*,* Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。
  • submodules (已弃用) – Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。
  • options (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参阅 StateDictOptions。

返回值:optimizers 的状态字典。

返回类型:OptimizerStateType


torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)

加载模型状态字典(state_dict)和优化器状态字典。

这是与 get_state_dict 相对应的操作,用于将状态字典设置到模型和优化器中。给定的 model_state_dictoptim_state_dict 不必由 get_state_dict 返回,但必须满足以下要求:

  1. 所有 FQN(完全限定名)必须符合 get_state_dict 中定义的规范格式;
  2. 如果张量是分片的,则必须是 ShardedTensor 或 DTensor 类型;
  3. 优化器状态字典不能包含参数 ID,其键应为规范化的 FQN。

警告:set_state_dict 只能在调用 backward() 之前或优化器执行 step() 之后调用,否则优化器状态将无法正确初始化。

参数

  • model (nn.Module) – 目标模型(nn.Module 实例)。
  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器(单个或可迭代集合)。
  • model_state_dict (Dict[str, ValueType]) – (联合类型 [Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):要加载的模型状态字典。若 model_state_dict 的键为 nn.Module,则该键是 model 的子模块,其值应为该子模块的状态字典。加载时会将子模块前缀自动附加到状态字典键名。
  • optim_state_dict (OptimizerStateType) – 要加载的优化器状态字典(OptimizerStateType 类型)。
  • options (StateDictOptions) – 控制模型和优化器状态字典加载方式的选项,详见 StateDictOptions 说明。

返回值

  • missing_keys:字符串列表,包含模型状态字典中缺失的键。
  • unexpected_keys:字符串列表,包含模型状态字典中意外的键。

返回类型:包含 missing_keysunexpected_keys 字段的命名元组(NamedTuple)


torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)

加载模型的状态字典(state_dict)。

这是get_model_state_dict的对应方法,用于将状态字典设置到模型上。详细用法请参考set_state_dict

参数

  • model (nn.Module) - 需要加载状态字典的nn.Module模型
  • model_state_dict Dict[str, ValueType]) - (Dict[str, ValueType]): 要加载的模型状态字典。如果model_state_dict的键是nn.Module类型,则该键是model的子模块,对应的值应该是该子模块的状态字典。加载时会将子模块的前缀附加到状态字典上。
  • options (StateDictOptions) - 控制如何加载模型状态字典和优化器状态字典的选项。详情请参阅StateDictOptions。

返回值

  • missing_keys 包含缺失键的字符串列表
  • unexpected_keys 包含意外键的字符串列表

返回类型:带有missing_keysunexpected_keys字段的NamedTuple


torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)

加载优化器的状态字典。

这是get_optimizer_state_dict的对应方法,用于将状态字典设置到优化器中。具体用法请参考set_state_dict

警告:set_optimizer_state_dict只能在优化器调用backward()之前或调用step()之后执行。否则,优化器状态将无法正确初始化。

参数

  • model (nn.Module) – 要操作的nn.Module模型。
  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化model的优化器或优化器集合。
  • optim_state_dict (OptimizerStateType) – OptimizerStateType类型:要加载的优化器状态字典。
  • options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。详情请参阅StateDictOptions。

返回值:无

返回类型:无


class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')

该数据类规定了 get_state_dict/set_state_dict 的工作机制:

  • full_state_dict:若设为 True,返回的 state_dict 中将收集所有张量,不会包含任何分片张量(ShardedTensor)或分布式张量(DTensor)。
  • cpu_offload:将所有张量卸载到 CPU。为防止 CPU 内存溢出(OOM),若同时启用 full_state_dict,则仅 rank0 会获取完整 state_dict,其他 rank 将获得空字典。
  • ignore_frozen_params:若为 True,返回的 state_dict 将排除所有冻结参数(即 requires_grad 为 False 的参数),默认值为 False。
  • keep_submodule_prefixes(已弃用):当指定 submodules 时,此选项决定是否保留 state_dict 键名中的子模块前缀。例如:若子模块为 module.pretrain 且参数完整限定名(FQN)为 pretrain.layer1.weight,启用该选项时返回的 state_dict 键名将保持为 pretrain.layer1.weight,禁用时则简化为 layer1.weight
    ⚠️ 注意:若禁用 keep_submodule_prefixes 可能导致 FQN 冲突,因此 submodules 应仅包含单个子模块。
  • strict:控制 set_state_dict 调用 model.load_state_dict() 时的严格模式。
  • broadcast_from_rank0:启用时,rank0 将接收完整 state_dict 并逐个广播其中的张量至其他 rank。其他 rank 会根据模型和优化器的本地分片情况接收并分片张量。使用此选项时必须启用 full_state_dict
    ⚠️ 当前仅支持 DTensor,不支持旧版 ShardedTensor。

针对习惯使用 torch.save 格式共享模型的用户,我们提供了以下离线工具方法用于格式转换:

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)

给定一个包含DCP检查点的目录,此函数会将其转换为Torch保存文件。

参数

  • dcp_checkpoint_dir ( Union [str,* PathLike]) - 包含DCP检查点的目录。
  • torch_save_path ( Union [str,* PathLike]) - 用于存储转换后的Torch保存文件的文件名。

警告:为避免内存不足(OOM),建议仅在单个rank上运行此函数。


torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)

给定 Torch 保存文件的位置,将其转换为 DCP 检查点。


参数

  • torch_save_path ( Union [str,* PathLike]) – Torch 保存文件的文件名。
  • dcp_checkpoint_dir ( Union [str,* PathLike]) – 存储 DCP 检查点的目录。

警告:为避免内存不足(OOM),建议仅在单个 rank 上运行此函数。

以下类也可用于从 torch.save 格式在线加载和重新分片模型。


class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)

StorageReader 用于读取 Torch 保存文件。该读取器会在协调器节点上读取整个检查点,然后将每个张量广播并分片到所有节点。

注意:需与 DynamicMetaLoadPlanner 配合使用。


警告:当前实现仅支持加载张量。


>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd, >>   storage_reader=BroadcastingTorchSaveReader(), >>   planner=DynamicMetaLoadPlanner(), >>   checkpoint_id="path_to_model.pt"
>>> )

prepare_global_plan(global_plan)

StorageReader 方法的实现

返回值类型:list [torch.distributed.checkpoint.planner.LoadPlan]


prepare_local_plan(plan)

StorageReader 方法的实现

返回类型:LoadPlan


read_data(plan, planner)

在协调器(coordinator)节点上读取 torch 保存的数据,随后进行广播

这会带来通信开销,但避免了在每个节点上加载完整检查点的需求,有望防止内存溢出(OOM)问题

返回类型:Future [None]


read_metadata()

扩展默认的 StorageReader 以支持构建元数据文件

返回类型:Metadata


reset(checkpoint_id=None)

StorageReader 方法的实现


set_up_storage_reader(metadata, is_coordinator)

StorageReader 方法的实现


CLASSMETHOD validate_checkpoint_id(checkpoint_id)

StorageReader 方法的实现

返回类型:bool


class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)

DefaultLoadPlanner的扩展实现,它会根据传入的状态字典创建新的元数据对象,从而避免从磁盘读取元数据的开销。这在读取没有独立元数据文件的格式(如Torch保存文件)时非常有用。

注意:该实现需与BroadcastingTorchSaveReader配合使用。

警告:当前实现仅支持加载张量(Tensors)。


>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd, >>   storage_reader=BroadcastingTorchSaveReader(), >>   planner=DynamicMetaLoadPlanner(), >>   checkpoint_id="path_to_model.pt"
>>> )

set_up_planner(state_dict, metadata=None, is_coordinator=False)

以下是翻译结果:

规划器的设置,通过从状态字典创建元数据对象来扩展默认行为

以下实验性接口可用于提升生产环境中的可观测性:


概率分布 - torch.distributions

distributions 包包含可参数化的概率分布和采样函数。这使得构建随机计算图和用于优化的随机梯度估计器成为可能。该包总体上遵循 TensorFlow Distributions 包的设计理念。

无法直接通过随机样本进行反向传播。然而,有两种主要方法可以创建可反向传播的替代函数:评分函数估计器/似然比估计器/REINFORCE 和路径导数估计器。REINFORCE 通常被视为强化学习中策略梯度方法的基础,而路径导数估计器常见于变分自编码器的重参数化技巧中。评分函数仅需要样本值 f(x)f(x)f(x),而路径导数则需要导数 f′(x)f’(x)f′(x)。接下来的章节将通过强化学习示例讨论这两种方法。更多细节请参阅 使用随机计算图的梯度估计。


评分函数

当概率密度函数对其参数可微时,我们只需要使用 sample()log_prob() 即可实现 REINFORCE 算法:

Δθ=αr∂log⁡p(a∣πθ(s))∂θ\Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}Δθ=αrθlogp(aπθ(s))

其中 θ\thetaθ 表示参数,α\alphaα 是学习率,rrr 代表奖励值,p(a∣πθ(s))p(a|\pi^\theta(s))p(aπθ(s)) 表示在状态 sss 下根据策略 πθ\pi^\thetaπθ 采取行动 aaa 的概率。

实际应用中,我们会从网络输出中采样一个动作,在环境中执行该动作,然后使用 log_prob 构建等效的损失函数。注意这里使用负号是因为优化器采用梯度下降法,而上述规则假设的是梯度上升。对于分类策略,实现 REINFORCE 的代码如下:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

路径导数

另一种实现这些随机/策略梯度的方法是使用rsample()方法中的重参数化技巧。通过这种方式,参数化的随机变量可以转化为一个无参数随机变量的确定性函数。因此,重参数化后的样本变得可微分。以下是实现路径导数的代码示例:

params = policy_network(state)
m = Normal(params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

分发


class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)

基类:object

Distribution 是概率分布的抽象基类。


property arg_constraints:  dict[str , torch.distributions.constraints.Constraint] 

返回一个从参数名到Constraint对象的字典,该字典应满足此分布每个参数的要求。非张量类型的参数无需出现在此字典中。


property batch_shape: Size 

返回参数批处理所应用的形状。


cdf(value)

返回在给定值处评估的累积密度/质量函数。

参数

  • value ( Tensor )

返回类型 : Tensor


entropy()

返回在 batch_shape 上批处理的分布熵。

返回值:形状为 batch_shape 的张量。

返回类型:Tensor


enumerate_support(expand=True)

返回包含离散分布所有可能取值的张量。结果将沿着第0维度进行枚举,因此输出形状为:(基数,) + 批次形状 + 事件形状(对于单变量分布,事件形状=())。

需注意:该方法会以同步锁步方式枚举所有批处理张量,例如[[0,0], [1,1], …]。当expand=False时,枚举仅沿第0维度进行,其余批次维度保持单一维度,形如[[0], [1], …]。

若要遍历完整的笛卡尔积,请使用itertools.product(m.enumerate_support())。

参数说明:

  • expand ([bool]) - 控制是否沿批次维度扩展支持集以匹配分布的batch_shape

返回值:
沿第0维度迭代的张量

返回类型:Tensor


property event_shape:  Size 

返回单个样本的形状(不包含批处理)。


expand(batch_shape, _instance=None)

返回一个新的分布实例(或填充由派生类提供的现有实例),并将批次维度扩展为batch_shape。该方法会在分布的参数上调用expand。因此,扩展后的分布实例不会分配新的内存。此外,当首次创建实例时,不会重复执行__init__.py中的任何参数检查或参数广播操作。

参数

  • batch_shape ( torch.Size ) – 期望扩展的尺寸。
  • _instance – 需要覆盖.expand方法的子类提供的新实例。

返回值:批次维度扩展至batch_size的新分布实例。


icdf(value)

返回在给定值处评估的逆累积密度/质量函数。

参数

  • value ( Tensor )

返回类型 : Tensor


log_prob(value)

返回在给定值处评估的概率密度/质量函数的对数。

参数

  • value ( Tensor )

返回类型 : Tensor


property mean:  Tensor 

返回该分布的均值。


property mode:  Tensor 

返回该分布的众数。


perplexity()

返回在 batch_shape 上批处理的分布困惑度。

返回值:形状为 batch_shape 的张量。

返回类型:Tensor


rsample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的重参数化样本,或者当分布参数为批处理时,生成形状为 sample_shape 的批量重参数化样本。

返回类型:Tensor


sample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的样本,如果分布参数是批处理的,则生成形状为 sample_shape 的批量样本。

返回类型:Tensor


sample_n(n)

生成 n 个样本,如果分布参数是批处理的,则生成 n 批样本。

返回类型:Tensor


static set_default_validate_args(value)

设置是否启用验证功能。

默认行为模仿 Python 的 assert 语句:验证功能默认开启,但如果 Python 以优化模式运行(通过 python -O 命令)则会自动关闭。由于验证过程可能消耗较多资源,当模型运行稳定后可以考虑禁用此功能。

参数说明

  • value ([bool]) – 控制是否启用验证的布尔值。

property stddev:  Tensor 

返回该分布的标准差。


property support: Optional[Constraint] 

返回一个表示该分布支撑集的 Constraint 对象。


property variance:  Tensor 

返回该分布的方差。


指数族分布


class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)

基类:Distribution

ExponentialFamily 是指数族概率分布的抽象基类,其概率质量/密度函数定义如下:

pF(x;θ)=exp⁡(⟨t(x),θ⟩−F(θ)+k(x))p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))pF(x;θ)=exp(⟨t(x),θF(θ)+k(x))

其中 θ\thetaθ 表示自然参数,t(x)t(x)t(x) 表示充分统计量,F(θ)F(\theta)F(θ) 是该族的对数归一化函数,k(x)k(x)k(x) 为载体测度。

说明:该类是 Distribution 类与属于指数族的分布之间的中间层,主要用于验证 .entropy() 和解析 KL 散度方法的正确性。我们利用该类通过自动微分框架和 Bregman 散度来计算熵与 KL 散度(基于 Frank Nielsen 和 Richard Nock 的研究成果《指数族的熵与交叉熵》)。


entropy()

通过计算对数归一化器的Bregman散度来计算熵的方法。


伯努利


class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)

基础分布:ExponentialFamily

创建一个由 probslogits 参数化的伯努利分布(但不可同时使用两者)。

样本为二元值(0 或 1)。以概率 p 取值为 1,以概率 1 - p 取值为 0。


示例:

>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([0.])

参数

  • probs (Number*,* Tensor ) – 采样结果为1的概率
  • logits (Number*,* Tensor ) – 采样结果为1的对数几率

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)

has_enumerate_support = True

log_prob(value)

property logits:  Tensor 

property mean:  Tensor 

property mode:  Tensor 

property param_shape:  Size 

property probs:  Tensor 

sample(sample_shape=torch.Size([]))

support = Boolean()

property variance:  Tensor 

Beta


class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)

基类:ExponentialFamily

concentration1concentration0 参数化的 Beta 分布。


示例:

>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([0.1046])

参数

  • concentration1 (float 或 Tensor) - 分布的第一个浓度参数(通常称为 alpha)
  • concentration0 (float 或 Tensor) - 分布的第二个浓度参数(通常称为 beta)

arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}

property concentration0:  Tensor 

property concentration1:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 
rsample(sample_shape=())

Return type : Tensor


support = Interval(lower_bound=0.0, upper_bound=1.0)

property variance: Tensor

Binomial


class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both). total_count must be broadcastable with probs/logits.


Example:


>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))>>> x = m.sample()tensor([ 0., 22., 71., 100.])>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))>>> x = m.sample()tensor([[4., 5.], [7., 6.]])

Parameters

  • total_count ( int or Tensor ) – number of Bernoulli trials
  • probs ( Tensor ) – Event probabilities
  • logits ( Tensor ) – Event log-odds

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}


entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)

has_enumerate_support = True

log_prob(value)

property logits: Tensor

property mean: Tensor

property mode: Tensor

property param_shape: Size


property probs:  Tensor
***

sample(sample_shape=torch.Size([]))

property support

Return type : _DependentProperty


property variance: Tensor

Categorical


class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a categorical distribution parameterized by either probs or logits (but not both).


Note: It is equivalent to the distribution that torch.multinomial()
samples from.

Samples are integers from {0,…,K−1}\{0, \ldots, K-1\}{0,…,K−1} where K is probs.size(-1).

If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.

If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.


Note: The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs
will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits
will return this normalized value.

See also: torch.multinomial()


Example:

>>> m = Categorical(torch.tensor([0.25, 0.25, 0.25, 0.25 ]))>>> m.sample()  # 0, 1, 2, 3 的采样概率均等tensor(3)

Parameters

  • probs ( Tensor ) – event probabilities
  • logits ( Tensor ) – event log probabilities (unnormalized)

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)

has_enumerate_support = True

log_prob(value)

property logits: Tensor

property mean: Tensor

property mode:  Tensor

property param_shape: Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

property support

Return type : _DependentProperty


property variance: Tensor

Cauchy


class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)

Bases: Distribution

Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.


Example:

>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))>>> m.sample()  # 从位置参数为0、尺度参数为1的柯西分布中采样tensor([2.3214])

Parameters

  • loc (float or Tensor ) – mode or median of the distribution.
  • scale (float or Tensor ) – half width at half maximum.

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

Return type : Tensor


support = Real()

property variance: Tensor
***

Chi2


class torch.distributions.chi2.Chi2(df, validate_args=None)

Bases: Gamma

Creates a Chi-squared distribution parameterized by shape parameter df.
This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5)


Example:

>>> m = Chi2(torch.tensor([1.0]))>>> m.sample()  # 自由度为1的卡方分布抽样tensor([0.1046])

Parameters

  • df (float or Tensor ) – shape parameter of the distribution

arg_constraints = {'df': GreaterThan(lower_bound=0.0)}

property df: Tensor

expand(batch_shape, _instance=None)

连续伯努利分布


class torch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)

基类:ExponentialFamily

创建一个由 probslogits 参数化的连续伯努利分布(两者不可同时使用)。

该分布的支持区间为 [0, 1],可通过 ‘probs’(取值在 (0,1) 区间)或 ‘logits’(实数)进行参数化。需要注意的是,与伯努利分布不同,这里的 ‘probs’ 并不对应概率,‘logits’ 也不对应对数几率,但由于与伯努利分布的相似性而沿用了相同名称。更多细节请参阅文献 [1]。


示例:

>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([0.2538])

参数

  • probs (Number*,* Tensor ) – 取值范围在(0,1)之间的参数
  • logits (Number*,* Tensor ) – 实数参数,其sigmoid值匹配’probs’

[1] 连续伯努利分布:修正变分自编码器中的一个普遍错误,Loaiza-Ganem G 和 Cunningham JP,NeurIPS 2019。https://arxiv.org/abs/1907.06845

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

cdf(value)
entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)
log_prob(value)

property logits: Tensor

(注:根据核心翻译原则第1条,代码块内容保持原样不翻译)


property mean: Tensor

property param_shape:  Size

property probs:  Tensor 

rsample(sample_shape=torch.Size([]))

Return type : Tensor


sample(sample_shape=torch.Size([]))

property stddev: Tensor


support = Interval(lower_bound=0.0, upper_bound=1.0)

property variance:  Tensor 

狄利克雷


class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)

基类:ExponentialFamily

创建一个由浓度参数 concentration 参数化的狄利克雷分布。


示例:

>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
tensor([0.1046, 0.8954])

参数

  • concentration ( Tensor ) - 分布的浓度参数(通常称为 alpha)

arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1)}

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=())

返回类型:Tensor


support = Simplex()

property variance:  Tensor 

指数函数


class torch.distributions.exponential.Exponential(rate, validate_args=None)

基类:ExponentialFamily

创建一个由 rate 参数化的指数分布。


示例:

>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([0.1046])

参数

  • rate (float 或 Tensor) – 该分布的 rate = 1 / scale

arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


property stddev:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance:  Tensor 

费希尔-斯涅克分布


class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)

基类:Distribution

创建一个由 df1df2 参数化的 Fisher-Snedecor 分布。


示例:

>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([0.2453])

参数

  • df1 (float 或 Tensor) – 自由度参数1
  • df2 (float 或 Tensor) – 自由度参数2

arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}expand(batch_shape, _instance=None)

has_rsample = True

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = GreaterThan(lower_bound=0.0)

property variance:  Tensor 

Gamma


class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)

基类:ExponentialFamily

创建一个由形状参数 concentration 和比率参数 rate 参数化的 Gamma 分布。


示例:

>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([0.1046])

参数

  • concentration (float 或 Tensor) - 分布的形状参数(通常称为 alpha)
  • rate (float 或 Tensor) - 分布的速率参数(通常称为 beta),rate = 1 / scale

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = GreaterThanEq(lower_bound=0.0)

property variance:  Tensor 

几何


class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)

基类:Distribution

创建一个由 probs 参数化的几何分布,其中 probs 表示伯努利试验的成功概率。

概率质量函数为:
P(X=k)=(1−p)kp,k=0,1,…P(X=k) = (1-p)^{k} p, k = 0, 1,
…P(X=k)=(1−p)kp,k=0,1,…

注意:
torch.distributions.geometric.Geometric() 将第 (k+1)(k+1)(k+1) 次试验视为首次成功,因此采样范围为 {0,1,…}\{0, 1, \ldots\}{0,1,…};
torch.Tensor.geometric_() 将第 k 次试验视为首次成功,因此采样范围为 {1,2,…}\{1, 2, \ldots\}{1,2,…}。


示例:

>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([2.])

参数

  • probs (Number*,* Tensor ) – 采样结果为1的概率值,必须在(0, 1]范围内
  • logits (Number*,* Tensor ) – 采样结果为1的对数几率值

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)} 

entropy()

expand(batch_shape, _instance=None)

log_prob(value)

property logits:  Tensor 

property mean:  Tensor 

property mode:  Tensor 

property probs:  Tensor 

sample(sample_shape=torch.Size([]))

support = IntegerGreaterThan(lower_bound=0)

property variance:  Tensor 

gumbel


class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)

基类:TransformedDistribution

从Gumbel分布中采样。


示例:

>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([1.0124])

参数

  • loc (float 或 Tensor) - 分布的位置参数
  • scale (float 或 Tensor) - 分布的尺度参数

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)} 

entropy()

expand(batch_shape, _instance=None)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property stddev:  Tensor 

support = Real()

property variance: Tensor

HalfCauchy


class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)

Bases: TransformedDistribution

Creates a half-Cauchy distribution parameterized by scale where:

X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)

Example:

>>> m = HalfCauchy(torch.tensor([1.0]))>>> m.sample()  # 从scale=1的半柯西分布中采样tensor([2.3214])

Parameters

  • scale (float or Tensor ) – scale of the full Cauchy distribution

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()


expand(batch_shape, _instance=None)

has_rsample = True


icdf(prob)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property scale:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance: Tensor

HalfNormal

class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)

Bases: TransformedDistribution

Creates a half-normal distribution parameterized by scale where:

X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)

Example:

>>> m = HalfNormal(torch.tensor([1.0]))>>> m.sample()  # 从scale=1的半正态分布中采样tensor([0.1046])

Parameters

  • scale (float or Tensor ) – scale of the full Normal distribution

arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(prob)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property scale:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance:  Tensor 

独立


class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)

基类:Distribution

将分布的部分批次维度重新解释为事件维度。

这一功能主要用于改变 log_prob() 返回结果的形状。例如,若想创建一个与多元正态分布形状相同的对角正态分布(使二者可互换),您可以:

>>> from torch.distributions.multivariate_normal import MultivariateNormal
>>> from torch.distributions.normal import Normal
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([]), torch.Size([3])]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([3]), torch.Size([])]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size([]), torch.Size([3])]

参数

  • base_distribution (torch.distributions.distribution.Distribution) – 基础分布
  • reinterpreted_batch_ndims ( int ) – 将被重新解释为事件维度的批次维度数量

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {} 

entropy()

enumerate_support(expand=True)


expand(batch_shape, _instance=None)

property has_enumerate_support:  bool 

property has_rsample:  bool 

log_prob(value)

property mean: Tensor

property mode: Tensor

rsample(sample_shape=torch.Size([]))

Return type : Tensor


sample(sample_shape=torch.Size([]))

Return type : Tensor


property support

Return type : _DependentProperty


property variance: Tensor

InverseGamma



class torch.distributions.inverse_gamma.InverseGamma(concentration, rate, validate_args=None)

Bases: TransformedDistribution

Creates an inverse gamma distribution parameterized by concentration and rate
where:

X ~ Gamma(concentration, rate)
Y = 1 / X ~ InverseGamma(concentration, rate)

Example:

>>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))>>> m.sample()tensor([1.2953])

Parameters

  • concentration (float or Tensor ) – shape parameter of the distribution
    (often referred to as alpha)
  • rate (float or Tensor ) – rate = 1 / scale of the distribution
    (often referred to as beta)

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}

property concentration:  Tensor

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

property mean:  Tensor 

property mode:  Tensor 

property rate:  Tensor 

支持范围 = GreaterThan(下限=0.0)


property variance: Tensor

Kumaraswamy


class torch.distributions.kumaraswamy.Kumaraswamy(concentration1, concentration0, validate_args=None)

Bases: TransformedDistribution

Samples from a Kumaraswamy distribution.


Example:

>>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))>>> m.sample()  # 从 alpha=1 和 beta=1 的 Kumaraswamy 分布中采样tensor([0.1729])

Parameters

  • concentration1 (float or Tensor ) – 1st concentration parameter of the distribution
    (often referred to as alpha)
  • concentration0 (float or Tensor ) – 2nd concentration parameter of the distribution
    (often referred to as beta)

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

property mean: Tensor

property mode: Tensor


support = Interval(lower_bound=0.0, upper_bound=1.0)

property variance:  Tensor 

LKJCholesky


class torch.distributions.lkj_cholesky.LKJCholesky(dim, concentration=1.0, validate_args=None)

基类:Distribution

LKJ分布用于描述相关矩阵的下三角Cholesky因子。

该分布由浓度参数η(concentration)控制,使得从Cholesky因子生成的相关矩阵M的概率与det(M)^{η-1}成正比。因此,当concentration == 1时,我们得到相关矩阵Cholesky因子的均匀分布。


L ~ LKJCholesky(dim, concentration)
X = L @ L' ~ LKJCorr(dim, concentration)

请注意,该分布是对相关矩阵的Cholesky因子进行采样,而非直接对相关矩阵本身采样,因此与文献[1]中关于LKJCorr分布的推导略有不同。在采样过程中,这里采用了文献[1]第3节所述的Onion方法。


示例:

>>> l = LKJCholesky(3, 0.5)
>>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
tensor([[1.0000, 0.0000, 0.0000], [0.3516, 0.9361, 0.0000], [-0.1899, 0.4748, 0.8593]])

参数

  • dimension (dim) – 矩阵的维度
  • concentration (float 或 Tensor) – 分布的形状参数/浓度参数(通常称为 eta)

参考文献

[1] Generating random correlation matrices based on vines and extended onion method (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe.

Journal of Multivariate Analysis. 100、10.1016/j.jmva.2009.04.008


arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}expand(batch_shape, _instance=None)

log_prob(value)

sample(sample_shape=torch.Size([]))

support = CorrCholesky()

Laplace


class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)

Bases: Distribution

Creates a Laplace distribution parameterized by loc and scale.


Example:

>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))>>> m.sample()  # 服从拉普拉斯分布,位置参数=0,尺度参数=1tensor([0.1046])

Parameters

  • loc (float or Tensor ) – mean of the distribution
  • scale (float or Tensor ) – scale of the distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 


rsample(sample_shape=torch.Size([]))

Return type : Tensor


property stddev:  Tensor

support = Real()

property variance:  Tensor 

对数正态分布


class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)

基类:TransformedDistribution

创建一个由locscale参数化的对数正态分布,其中:

***
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)

Example:


>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with mean=0 and stddev=1
tensor([0.1046])

参数

  • loc (float 或 Tensor) - 分布对数的均值
  • scale (float 或 Tensor) - 分布对数的标准差

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

property loc:  Tensor 

property mean:  Tensor 

property mode:  Tensor 

property scale:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance: Tensor

LowRankMultivariateNormal


class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)

Bases: Distribution

Creates a multivariate normal distribution with covariance matrix having a low-rank form
parameterized by cov_factor and cov_diag:

covariance_matrix = cov_factor @ cov_factor.T + cov_diag

Example :

>>> m = LowRankMultivariateNormal(
...     torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
... )
>>> m.sample()  # 服从均值=`[0,0]`、协方差因子=`[[1],[0]]`、对角协方差=`[1,1]`的正态分布
tensor([-0.2102, -0.5429])

Parameters

  • loc ( Tensor ) – mean of the distribution with shape batch_shape + event_shape
  • cov_factor ( Tensor ) – factor part of low-rank form of covariance matrix with shape
    batch_shape + event_shape + (rank,)
  • cov_diag ( Tensor ) – diagonal part of low-rank form of covariance matrix with shape
    batch_shape + event_shape

Note: The computation for determinant and inverse of covariance matrix is avoided when
cov_factor.shape[1] << cov_factor.shape[0] thanks to Woodbury matrix identity and matrix determinant lemma.
Thanks to these formulas, we just need to compute the determinant and inverse of the small size “capacitance” matrix:

capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor

arg_constraints = {'cov_diag': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': IndependentConstraint(Real(), 1)}

property covariance_matrix:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property precision_matrix:  Tensor 

rsample(sample_shape=torch.Size([]))

Return type : Tensor


property scale_tril:  Tensor


support = IndependentConstraint(Real(), 1)

property variance:  Tensor 

混合相同族分布


class torch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution, component_distribution, validate_args=None)

基类:Distribution

MixtureSameFamily 分布实现了(批量)混合分布,其中所有组件都来自同一分布类型的不同参数化形式。它通过一个分类"选择分布"(覆盖k个组件)和一个组件分布进行参数化,其中组件分布是一个具有最右侧批量形状(等于[k])的Distribution,用于索引每个(批量的)组件。


示例:

>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
>>> # weighted normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
>>> gmm = MixtureSameFamily(mix, comp)>>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
>>> # weighted bivariate normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Independent(D.Normal(
...          torch.randn(5,2), torch.rand(5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)>>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
>>> # consisting of 5 random weighted bivariate normal distributions
>>> mix = D.Categorical(torch.rand(3,5))
>>> comp = D.Independent(D.Normal(
...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)

参数

  • mixture_distribution (Categorical) – 类似 torch.distributions.Categorical 的实例,用于管理选择组件的概率。类别数量必须与 component_distribution 最右侧的批次维度匹配。必须具有标量 batch_shape 或与 component_distribution.batch_shape[:-1] 匹配的 batch_shape。
  • component_distribution (Distribution) – 类似 torch.distributions.Distribution 的实例。最右侧的批次维度用于索引组件。

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {}

cdf(x)

property component_distribution: Distribution

expand(batch_shape, _instance=None)

has_rsample = False

log_prob(x)

property mean: Tensor

property mixture_distribution: Categorical


sample(sample_shape=torch.Size([]))

property support

Return type : _DependentProperty


property variance: Tensor

Multinomial


class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.

Note that total_count need not be specified if only log_prob() is called (see example below)


Note: The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs
will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits
will return this normalized value.

  • sample() requires a single shared total_count for all
    parameters and samples.
  • log_prob() allows different total_count for each parameter and sample.

Example:

>>> m = Multinomial(100, torch.tensor([1., 1., 1., 1.]))>>> x = m.sample()  # 0, 1, 2, 3 的采样概率均等tensor([21., 24., 30., 25.])>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)tensor([-4.1338])

Parameters

  • total_count ( int ) – number of trials
  • probs ( Tensor ) – event probabilities
  • logits ( Tensor ) – event log probabilities (unnormalized)

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

entropy()

expand(batch_shape, _instance=None)

log_prob(value)

property logits: Tensor

property mean: Tensor

property param_shape:  Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

property support 

返回类型:_DependentProperty

total_count:int


property variance:  Tensor 

多元正态分布


class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)

基类:Distribution

创建一个由均值向量和协方差矩阵参数化的多元正态(也称为高斯)分布。

多元正态分布可以通过以下三种方式参数化:
1、正定协方差矩阵 Σ\mathbf{\Sigma}Σ
2、正定精度矩阵 Σ−1\mathbf{\Sigma}^{-1}Σ−1
3、具有正对角元素的下三角矩阵 L\mathbf{L}L(满足 Σ=LL⊤\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\topΣ=LL⊤)

该三角矩阵可以通过协方差矩阵的Cholesky分解等方法获得。


示例

>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])

参数

  • loc ( Tensor ) – 分布的均值
  • covariance_matrix ( Tensor ) – 正定协方差矩阵
  • precision_matrix ( Tensor ) – 正定精度矩阵
  • scale_tril ( Tensor ) – 协方差的下三角因子,对角线元素为正

注意:只能指定 covariance_matrixprecision_matrixscale_tril 中的一个参数。

使用 scale_tril 会更高效:所有内部计算都基于 scale_tril。如果传入的是 covariance_matrixprecision_matrix,则仅用于通过 Cholesky 分解计算对应的下三角矩阵。


arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': IndependentConstraint(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}

property covariance_matrix:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property precision_matrix:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


property scale_tril:  Tensor 

support = IndependentConstraint(Real(), 1)

property variance: Tensor

NegativeBinomial


class torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a Negative Binomial distribution, i.e. distribution of the number of successful independent and identical Bernoulli trials
before total_count failures are achieved. The probability of success of each Bernoulli trial is probs.


Parameters

  • total_count (float or Tensor ) – non-negative number of negative Bernoulli
    trials to stop, although the distribution is still valid for real
    valued count
  • probs ( Tensor ) – Event probabilities of success in the half open interval [0, 1)
  • logits ( Tensor ) – Event log-odds for probabilities of success

arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}

expand(batch_shape, _instance=None)


log_prob(value)

property logits: Tensor

property mean: Tensor

property mode: Tensor

property param_shape:  Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

support = IntegerGreaterThan(lower_bound=0)

property variance:  Tensor 

常规


class torch.distributions.normal.Normal(loc, scale, validate_args=None)

基类:ExponentialFamily

创建一个由locscale参数化的正态(也称为高斯)分布。


示例:

>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([0.1046])

参数

  • loc (float 或 Tensor) - 分布的均值(通常称为 mu)
  • scale (float 或 Tensor) - 分布的标准差(通常称为 sigma)

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


sample(sample_shape=torch.Size([]))

property stddev:  Tensor 

support = Real()

property variance:  Tensor 

OneHotCategorical


class torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)

基类:Distribution

创建一个由 probslogits 参数化的 one-hot 分类分布。

样本是大小为 probs.size(-1) 的 one-hot 编码向量。

注意:probs 参数必须是非负、有限且具有非零和,它将在最后一个维度上被归一化为总和为 1。probs 将返回这个归一化后的值。

logits 参数将被解释为未归一化的对数概率,因此可以是任何实数。它同样会被归一化,使得最终概率在最后一个维度上总和为 1。logits 将返回这个归一化后的值。

另请参阅:torch.distributions.Categorical() 以了解 probslogits 的详细说明。


示例:

>>> m = OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([0., 0., 0., 1.])

参数

  • probs ( Tensor ) – 事件概率
  • logits ( Tensor ) – 事件对数概率(未归一化)

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)```

has_enumerate_support = True


log_prob(value)

property logits: Tensor


property mean:  Tensor

property mode: Tensor

property param_shape:  Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

support = OneHot()

property variance: Tensor

Pareto


class torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)

Bases: TransformedDistribution

Samples from a Pareto Type 1 distribution.


Example:

>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))>>> m.sample()  # 从scale=1且alpha=1的帕累托分布中采样tensor([1.5623])

Parameters

  • scale (float or Tensor ) – Scale parameter of the distribution
  • alpha (float or Tensor ) – Shape parameter of the distribution

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

Return type : Tensor

expand(batch_shape, _instance=None)

Return type
Pareto


property mean: Tensor

property mode:  Tensor

property support: Constraint

Return type : _DependentProperty


property variance: Tensor

Poisson


class torch.distributions.poisson.Poisson(rate, validate_args=None)

Bases: ExponentialFamily

Creates a Poisson distribution parameterized by rate, the rate parameter.

Samples are nonnegative integers, with a pmf given by
rateke−ratek!\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

ratekk!e−rate​Example:

>>> m = Poisson(torch.tensor([4]))
>>> m.sample()
tensor([3.])

Parameters

  • rate (Number*,* Tensor ) – the rate parameter

arg_constraints = {'rate': GreaterThanEq(lower_bound=0.0)}expand(batch_shape, _instance=None)

log_prob(value)

property mean: Tensor

property mode: Tensor

sample(sample_shape=torch.Size([]))

support = IntegerGreaterThan(lower_bound=0) 

property variance:  Tensor 

松弛伯努利分布

(注:根据技术文档翻译原则,此处保留原英文术语"RelaxedBernoulli"作为专有名词不翻译,仅对标题层级和格式符号进行本地化处理。技术文档中常见的分布名称通常保留原文以确保准确性。)


class torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)

基类:TransformedDistribution

创建一个RelaxedBernoulli分布,参数化方式为temperature,以及probslogits(但不可同时使用)。这是伯努利分布的松弛版本,因此取值范围在(0, 1)之间,并且具有可重参数化的样本。


示例:

>>> m = RelaxedBernoulli(torch.tensor([2.2]), 
...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([0.2951, 0.3442, 0.8918, 0.9021])

参数

  • temperature ( Tensor ) – 松弛温度
  • probs (Number*,* Tensor ) – 采样结果为1的概率
  • logits (Number*,* Tensor ) – 采样结果为1的对数几率

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

expand(batch_shape, _instance=None)

has_rsample = True

property logits:  Tensor 

property probs:  Tensor 

support = Interval(lower_bound=0.0, upper_bound=1.0)

property temperature:  Tensor 

LogitRelaxedBernoulli


class torch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)

基类:Distribution

创建一个由 probslogits(但不同时使用)参数化的 LogitRelaxedBernoulli 分布,这是 RelaxedBernoulli 分布的对数几率。

采样结果是 (0, 1) 区间值的对数几率。更多细节参见[1]。


参数

  • temperature ( Tensor ) – 松弛温度参数
  • probs (Number*,* Tensor ) – 采样结果为 1 的概率
  • logits (Number*,* Tensor ) – 采样结果为 1 的对数优势比

参考文献
[1] 《具体分布:离散随机变量的连续松弛方法》(Maddison 等人,2017)
[2] 《基于 Gumbel-Softmax 的类别重参数化方法》(Jang 等人,2017)

注:
1、保留所有代码块和链接原格式
2、技术术语如"logits"、“tensor"等保持英文
3、被动语态转为主动语态(如"parameterized by"译为"由…参数化的”)
4、数学符号区间(0,1)保留原格式
5、文献标题采用中文书名号并补充说明性文字"方法"


arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}expand(batch_shape, _instance=None)

log_prob(value)

property logits:  Tensor 

property param_shape:  Size 

property probs:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = Real()

RelaxedOneHotCategorical


class torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)

基类:TransformedDistribution

创建一个由 temperature 以及 probslogits 参数化的 RelaxedOneHotCategorical 分布。

这是 OneHotCategorical 分布的松弛版本,因此其样本位于单纯形上,并且可重新参数化。


示例:

>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), 
...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([0.1294, 0.2324, 0.3859, 0.2523])

参数

  • temperature ( Tensor ) – 松弛温度
  • probs ( Tensor ) – 事件概率
  • logits ( Tensor ) – 每个事件的未归一化对数概率

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

expand(batch_shape, _instance=None)

has_rsample = True

property logits:  Tensor 

property probs:  Tensor 

support = Simplex()

property temperature:  Tensor 

StudentT 分布


class torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)

基类:Distribution

创建一个由自由度 df、均值 loc 和尺度参数 scale 参数化的学生t分布。


示例:

>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([0.1046])

参数

  • df (float 或 Tensor) – 自由度
  • loc (float 或 Tensor) – 分布的平均值
  • scale (float 或 Tensor) – 分布的尺度参数

arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)


 

has_rsample = True



log_prob(value)

property mean: Tensor

property mode: Tensor

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = Real()

property variance: Tensor

TransformedDistribution


class torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)

(说明:根据核心翻译原则第1条"代码保护",所有代码块保持原内容不处理,因此上述Python类定义未作翻译,完整保留原始格式和内容)


Bases: Distribution

Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:

X ~ BaseDistributionY = f(X) ~ TransformedDistribution(BaseDistribution, f)log p(Y) = log p(X) + log |det (dX/dY)|

Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.

An example for the usage of TransformedDistribution would be:

# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

For more examples, please look at the implementations of Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical


arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {}

cdf(value)

通过反转变换并计算基础分布的得分来计算累积分布函数。


expand(batch_shape, _instance=None)

property has_rsample:  bool 

icdf(value)

通过变换计算逆累积分布函数,并得出基础分布的评分值。


log_prob(value)

通过逆变换计算样本得分,利用基础分布的得分和对数绝对雅可比行列式进行评分。


rsample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的重参数化样本,或者当分布参数为批处理时,生成形状为 sample_shape 的批量重参数化样本。首先从基础分布中采样,然后对列表中的每个变换应用 transform() 方法。

返回类型:Tensor


sample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的样本,如果分布参数是批处理的,则生成形状为 sample_shape 的样本批次。首先从基础分布中采样,然后对列表中的每个变换应用 transform() 方法。


property support 

返回类型:_DependentProperty


统一性


class torch.distributions.uniform.Uniform(low, high, validate_args=None)

基础分布:Distribution

生成在半开区间 [low, high) 内均匀分布的随机样本。


示例:

>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([2.3418])

参数

  • low (float 或 Tensor) - 下限值(包含)
  • high (float 或 Tensor) - 上限值(不包含)

arg_constraints = {'high': Dependent(), 'low': Dependent()}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean: Tensor

property mode: Tensor

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


property stddev:  Tensor 

property support 

返回类型:_DependentProperty


property variance:  Tensor 

冯·米塞斯


class torch.distributions.von_mises.VonMises(loc, concentration, validate_args=None)

基类:Distribution

圆形冯·米塞斯分布。

该实现采用极坐标系。locvalue参数可以是任意实数(以便进行无约束优化),但会被解释为对2π取模的角度值。


示例:

>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # von Mises distributed with loc=1 and concentration=1
tensor([1.9777])

参数

  • loc (torch.Tensor) - 以弧度表示的角度值
  • concentration (torch.Tensor) - 集中度参数

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}expand(batch_shape, _instance=None)


has_rsample = False log_prob(value)

property mean:  Tensor 

提供的平均值为循环平均值。


property mode:  Tensor 

sample(sample_shape=torch.Size([]))

The sampling algorithm for the von Mises distribution is based on the following paper: D.J. Best and N.I. Fisher, “Efficient simulation of the von Mises distribution.” Applied Statistics (1979): 152-157.

Sampling is always done in double precision internally to avoid a hang in _rejection_sample() for small values of the concentration, which starts to happen for single precision around 1e-4 (see issue #88443).


support = Real()

property variance:  Tensor 

提供的方差为圆形方差。


weibull


class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)

基类:TransformedDistribution

从双参数威布尔分布中采样的实现。


示例

>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([0.4784])

参数

  • scale (float 或 Tensor) - 分布的尺度参数(lambda)。
  • concentration (float 或 Tensor) - 分布的集中度参数(k/shape)。

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)

property mean:  Tensor 

property mode:  Tensor 


support = GreaterThan(lower_bound=0.0)

property variance:  Tensor 

wishart


class torch.distributions.wishart.Wishart(df, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)

基类:ExponentialFamily

创建一个由对称正定矩阵Σ\SigmaΣ或其Cholesky分解Σ=LL⊤\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\topΣ=LL⊤参数化的Wishart分布。


示例

>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
>>> m.sample()  # Wishart distributed with mean=`df * I` and >># variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

参数

  • df (float 或 Tensor) – 实值参数,需大于(方阵的维度)- 1
  • covariance_matrix (Tensor) – 正定协方差矩阵
  • precision_matrix (Tensor) – 正定精度矩阵
  • scale_tril (Tensor) – 协方差矩阵的下三角因子,其对角线元素为正

注意:只能指定 covariance_matrixprecision_matrixscale_tril 中的一个。

使用 scale_tril 会更高效:所有内部计算都基于 scale_tril。如果传入的是 covariance_matrixprecision_matrix,则仅用于通过 Cholesky 分解计算对应的下三角矩阵。

torch.distributions.LKJCholesky 是一种受限的 Wishart 分布。[1]

参考文献

[1] Wang, Z., Wu, Y. 和 Chu, H., 2018、关于 LKJ 分布与受限 Wishart 分布的等价性。

[2] Sawyer, S., 2007、Wishart 分布与逆 Wishart 采样。

[3] Anderson, T. W., 2003、多元统计分析导论(第 3 版)。

[4] Odell, P. L. 和 Feiveson, A. H., 1966、生成样本协方差矩阵的数值方法。JASA, 61(313):199-203。

[5] Ku, Y.-C. 和 Bloomfield, P., 2010、在 OX 中生成具有分数自由度的随机 Wishart 矩阵。


arg_constraints = {'covariance_matrix': PositiveDefinite(), 'df': GreaterThan(lower_bound=0), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}

property covariance_matrix:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property precision_matrix:  Tensor 

rsample(sample_shape=torch.Size([]), max_try_correction=None)

Warning: In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
Several tries to correct singular samples are performed by default, but it may end up returning
singular matrix samples. Singular samples may return -inf values in .log_prob().
In those cases, the user should validate the samples and either fix the value of df or adjust max_try_correction value for argument in .rsample accordingly.

Return type : Tensor


property scale_tril:  Tensor

support = PositiveDefinite()

property variance: Tensor

KL Divergence

`torch.distributions.kl.kl_divergence(p, q)` 

Compute Kullback-Leibler divergence KL(p∥q)KL(p | q)KL(p∥q) between two distributions.

KL(p∥q)=∫p(x)log⁡p(x)q(x) dxKL(p | q) = \int p(x) \log\frac {p(x)} {q(x)} \,dxKL(p∥q)=∫p(x)logq(x)p(x)​dx


Parameters

  • p (Distribution) – A Distribution object.
  • q (Distribution) – A Distribution object.

Returns
A batch of KL divergences of shape batch_shape.

Return type : Tensor

Raises
NotImplementedError – If the distribution types have not been registered via [register_kl()`](https://pytorch.org/docs/stable/data.html#torch.distributions.kl.register_kl “torch.distributions.kl.register_kl”).

KL divergence is currently implemented for the following distribution pairs:* Bernoulli and Bernoulli

  • Bernoulli and Poisson
  • Beta and Beta
  • Beta and ContinuousBernoulli
  • Beta and Exponential
  • Beta and Gamma
  • Beta and Normal
  • Beta and Pareto
  • Beta and Uniform
  • Binomial and Binomial
  • Categorical and Categorical
  • Cauchy and Cauchy
  • ContinuousBernoulli and ContinuousBernoulli
  • ContinuousBernoulli and Exponential
  • ContinuousBernoulli and Normal
  • ContinuousBernoulli and Pareto
  • ContinuousBernoulli and Uniform
  • Dirichlet and Dirichlet
  • Exponential and Beta
  • Exponential and ContinuousBernoulli
  • Exponential and Exponential
  • Exponential and Gamma
  • Exponential and Gumbel
  • Exponential and Normal
  • Exponential and Pareto
  • Exponential and Uniform
  • ExponentialFamily and ExponentialFamily
  • Gamma and Beta
  • Gamma and ContinuousBernoulli
  • Gamma and Exponential
  • Gamma and Gamma
  • Gamma and Gumbel
  • Gamma and Normal
  • Gamma and Pareto
  • Gamma and Uniform
  • Geometric and Geometric
  • Gumbel and Beta
  • Gumbel and ContinuousBernoulli
  • Gumbel and Exponential
  • Gumbel and Gamma
  • Gumbel and Gumbel
  • Gumbel and Normal
  • Gumbel and Pareto
  • Gumbel and Uniform
  • HalfNormal and HalfNormal
  • Independent and Independent
  • Laplace and Beta
  • Laplace and ContinuousBernoulli
  • Laplace and Exponential
  • Laplace and Gamma
  • Laplace and Laplace
  • Laplace and Normal
  • Laplace and Pareto
  • Laplace and Uniform
  • LowRankMultivariateNormal and LowRankMultivariateNormal
  • LowRankMultivariateNormal and MultivariateNormal
  • MultivariateNormal and LowRankMultivariateNormal
  • MultivariateNormal and MultivariateNormal
  • Normal and Beta
  • Normal and ContinuousBernoulli
  • Normal and Exponential
  • Normal and Gamma
  • Normal and Gumbel
  • Normal and Laplace
  • Normal and Normal
  • Normal and Pareto
  • Normal and Uniform
  • OneHotCategorical and OneHotCategorical
  • Pareto and Beta
  • Pareto and ContinuousBernoulli
  • Pareto and Exponential
  • Pareto and Gamma
  • Pareto and Normal
  • Pareto and Pareto
  • Pareto and Uniform
  • Poisson and Bernoulli
  • Poisson and Binomial
  • Poisson and Poisson
  • TransformedDistribution and TransformedDistribution
  • Uniform and Beta
  • Uniform and ContinuousBernoulli
  • Uniform and Exponential
  • Uniform and Gamma
  • Uniform and Gumbel
  • Uniform and Normal
  • Uniform and Pareto
  • Uniform and Uniform

torch.distributions.kl.register_kl(type_p, type_q)

Decorator to register a pairwise function with kl_divergence().
Usage:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):# insert implementation here

Lookup returns the most specific (type,type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example to resolve the ambiguous situation:

@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...@register_kl(DerivedP, BaseQ) 
def kl_version2(p, q): ...

you should register a third most-specific implementation, e.g.:

register_kl(DerivedP, DerivedQ)(kl_version1)  # 打破平局

Parameters

  • type_p (type) – A subclass of Distribution.
  • type_q (type) – A subclass of Distribution.

Transforms


class torch.distributions.transforms.AbsTransform(cache_size=0)

Transform via the mapping y=∣x∣y = |x|y=∣x∣.


class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)

Transform via the pointwise affine mapping y=loc+scale×xy = \text{loc} + \text{scale} \times xy=loc+scale×x.


Parameters

  • loc ( Tensor or float) – Location parameter.
  • scale ( Tensor or float) – Scale parameter.
  • event_dim ( int ) – Optional size of event_shape. This should be zerofor univariate random variables, 1 for distributions over vectors, 2 for distributions over matrices, etc.

class torch.distributions.transforms.CatTransform(tseq, dim=0, lengths=None, cache_size=0)

(注:根据核心翻译原则第1条"代码保护"规则,代码块内容保持原样不翻译)


Transform functor that applies a sequence of transforms tseq
component-wise to each submatrix at dim, of length lengths[dim], in a way compatible with torch.cat().


Example:

x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)x = torch.cat([x0, x0], dim=0)t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])t = CatTransform([t0, t0], dim=0, lengths=[20, 20])y = t(x)

class torch.distributions.transforms.ComposeTransform(parts, cache_size=0)

Composes multiple transforms in a chain.
The transforms being composed are responsible for caching.


Parameters

  • parts (list of Transform ) – A list of transforms to compose.
  • cache_size ( int ) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.

class torch.distributions.transforms.CorrCholeskyTransform(cache_size=0)

Transforms an uncontrained real vector xxx with length D∗(D−1)/2D*(D-1)/2D∗(D−1)/2 into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
triangular matrix with positive diagonals and unit Euclidean norm for each row.
The transform is processed as follows:

1、First we convert x into a lower triangular matrix in row order.
2、For each row XiX_iXi​ of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform XiX_iXi​ into a unit Euclidean length vector using the following steps:

  • Scales into the interval (−1,1)(-1, 1)(−1,1) domain: ri=tanh⁡(Xi)r_i = \tanh(X_i)ri​=tanh(Xi​).
  • Transforms into an unsigned domain: zi=ri2z_i = r_i^2zi​=ri2​.
  • Applies si=StickBreakingTransform(zi)s_i = StickBreakingTransform(z_i)si​=StickBreakingTransform(zi​).
  • Transforms back into signed domain: yi=sign(ri)∗siy_i = sign(r_i) * \sqrt{s_i}yi​=sign(ri​)∗si​​.

class torch.distributions.transforms.CumulativeDistributionTransform(distribution, cache_size=0)

Transform via the cumulative distribution function of a probability distribution.


Parameters

  • distribution (Distribution) – Distribution whose cumulative distribution function to use for the transformation.

Example:

# 从多元正态分布构建高斯Copula
base_dist = MultivariateNormal(loc=torch.zeros(2), scale_tril=LKJCholesky(2).sample(), )
transform = CumulativeDistributionTransform(Normal(0, 1))
copula = TransformedDistribution(base_dist, [transform])

class torch.distributions.transforms.ExpTransform(cache_size=0)

Transform via the mapping y=exp⁡(x)y = \exp(x)y=exp(x).


class torch.distributions.transforms.IndependentTransform(base_transform, reinterpreted_batch_ndims, cache_size=0)

Wrapper around another transform to treat
reinterpreted_batch_ndims-many extra of the right most dimensions as dependent. This has no effect on the forward or backward transforms, but
does sum out reinterpreted_batch_ndims-many of the rightmost dimensions in log_abs_det_jacobian().


Parameters

  • base_transform ( Transform ) – A base transform.
  • reinterpreted_batch_ndims ( int ) – The number of extra rightmost
    dimensions to treat as dependent.

class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)

Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.

This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.


class torch.distributions.transforms.PositiveDefiniteTransform(cache_size=0)

(说明:根据核心翻译原则第1条"代码保护"规则,所有代码块保持原内容不处理)


Transform from unconstrained matrices to positive-definite matrices.


class torch.distributions.transforms.PowerTransform(exponent, cache_size=0)

Transform via the mapping y=xexponenty = x^{\text{exponent}}y=xexponent.


class torch.distributions.transforms.ReshapeTransform(in_shape, out_shape, cache_size=0)

Unit Jacobian transform to reshape the rightmost part of a tensor.

Note that in_shape and out_shape must have the same number of elements, just as for torch.Tensor.reshape().


Parameters

  • in_shape ( torch.Size ) – The input event shape.
  • out_shape ( torch.Size ) – The output event shape.
  • cache_size ( int ) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported. (Default 0.)

class torch.distributions.transforms.SigmoidTransform(cache_size=0)

Transform via the mapping y=11+exp⁡(−x)y = \frac{1}{1 + \exp(-x)}y=1+exp(−x)1​ and x=logit(y)x = \text{logit}(y)x=logit(y).


class torch.distributions.transforms.SoftplusTransform(cache_size=0)

Transform via the mapping Softplus(x)=log⁡(1+exp⁡(x))\text{Softplus}(x) = \log(1 + \exp(x))Softplus(x)=log(1+exp(x)).
The implementation reverts to the linear function when x>20x 20x>20、


class torch.distributions.transforms.TanhTransform(cache_size=0)

Transform via the mapping y=tanh⁡(x)y=tanh⁡(x)y=tanh(x).

It is equivalent to

ComposeTransform([AffineTransform(0.0, 2.0),SigmoidTransform(),AffineTransform(-1.0, 2.0),]
)

However this might not be numerically stable, thus it is recommended to use TanhTransform
instead.

Note that one should use cache_size=1 when it comes to NaN/Inf values.


class torch.distributions.transforms.SoftmaxTransform(cache_size=0)

SoftmaxTransform 是 PyTorch 分布变换类,用于实现 softmax 变换。该变换通常用于将未归一化的 logits 转换为概率分布。cache_size 参数控制变换结果的缓存大小,设置为 0 表示不缓存。


Transform from unconstrained space to the simplex via y=exp⁡(x)y = \exp(x)y=exp(x) then
normalizing.

This is not bijective and cannot be used for HMC. However this acts mostly
coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms.


class torch.distributions.transforms.StackTransform(tseq, dim=0, cache_size=0)

Transform functor that applies a sequence of transforms tseq
component-wise to each submatrix at dim in a way compatible with torch.stack().


Example:

x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)t = StackTransform([ExpTransform(), identity_transform], dim=1)y = t(x)

class torch.distributions.transforms.StickBreakingTransform(cache_size=0)

Transform from unconstrained space to the simplex of one additional
dimension via a stick-breaking process.

This transform arises as an iterated sigmoid transform in a stick-breaking
construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.

This is bijective and appropriate for use in HMC; however it mixes
coordinates together and is less appropriate for optimization.


class torch.distributions.transforms.Transform(cache_size=0)

Abstract class for invertable transformations with computable log
det jacobians. They are primarily used in torch.distributions.TransformedDistribution.

Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values
since the autograd graph may be reversed. For example while the following
works with or without caching:

y = t(x)t.log_abs_det_jacobian(x, y).backward()  # x将接收梯度。

However the following will error when caching due to dependency reversal:

y = t(x)z = t.inv(y)grad(z.sum(), [y])  # 报错,因为 z 就是 x

Derived classes should implement one or both of _call() or _inverse(). Derived classes that set bijective=True should also
implement log_abs_det_jacobian().


Parameters

  • cache_size ( int ) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.

Variables

  • domain (Constraint) – The constraint representing valid inputs to this transform.
  • codomain (Constraint) – The constraint representing valid outputs to this transform
    which are inputs to the inverse transform.
  • bijective ([bool]) – Whether this transform is bijective. A transform
    t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least
    maintain the weaker pseudoinverse properties
    t(t.inv(t(x)) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
  • sign ( int or Tensor ) – For bijective univariate transforms, this should be +1 or -1 depending on whether transform is monotone
    increasing or decreasing.

property inv: Transform

Returns the inverse Transform of this transform.
This should satisfy t.inv.inv is t.


property sign: int

Returns the sign of the determinant of the Jacobian, if applicable.
In general this only makes sense for bijective transforms.


log_abs_det_jacobian(x, y)

Computes the log det jacobian log |dy/dx| given input and output.



forward_shape(shape)

Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.



inverse_shape(shape)

Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.


Constraints


class torch.distributions.constraints.Constraint

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

Variables

  • is_discrete ([bool]) – Whether constrained space is discrete.
    Defaults to False.
  • event_dim ( int ) – Number of rightmost dimensions that together define an event. The check() method will remove this many dimensions
    when computing validity.


check(value)

Returns a byte tensor of sample_shape + batch_shape indicating
whether each event in value satisfies this constraint.


torch.distributions.constraints.cat

alias of _Cat


torch.distributions.constraints.dependent_property

alias of _DependentProperty


torch.distributions..constraints.greater_than

alias of _GreaterThan


torch.distributions..constraints.greater_than_eq

alias of _GreaterThanEq


torch.distributions..constraints.independent

alias of _IndependentConstraint


torch.distributions..constraints.integer_interval

alias of _IntegerInterval


torch.distributions..constraints.interval

alias of _Interval


torch.distributions..constraints.half_open_interval

alias of _HalfOpenInterval


torch.distributions..constraints.is_dependent(constraint)

Checks if constraint is a _Dependent object.


Parameters

  • constraint – A Constraint object.

Returns
True if constraint can be refined to the type _Dependent, False otherwise.

Return type
bool


Examples

>>> import torch>>> from torch.distributions import Bernoulli>>> from torch.distributions.constraints import is_dependent


>>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))>>> constraint1 = dist.arg_constraints["probs"]>>> constraint2 = dist.arg_constraints["logits"]

>>> for constraint in [constraint1, constraint2]:if is_dependent(constraint):continue

torch.distributions.constraints.less_than 

alias of _LessThan


torch.distributions..constraints.multinomial 

alias of _Multinomial


torch.distributions..constraints.stack

alias of _Stack


Constraint Registry

PyTorch provides two global ConstraintRegistry objects that link Constraint objects to Transform objects. These objects both input constraints and return transforms, but they have different guarantees on bijectivity.

1、biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().
2、transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().

The transform_to() registry is useful for performing unconstrained optimization on constrained parameters of probability distributions, which are indicated by each distribution’s .arg_constraints dict. These transforms often overparameterize a space in order to avoid rotation; they are thus more suitable for coordinate-wise optimization algorithms like Adam:

loc = torch.zeros(100, requires_grad=True)unconstrained = torch.zeros(100, requires_grad=True)scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)loss = -Normal(loc, scale).log_prob(data).sum()

The biject_to() registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained .support are propagated in an unconstrained space, and algorithms are typically rotation invariant.:

dist = Exponential(rate)unconstrained = torch.zeros(100, requires_grad=True)sample = biject_to(dist.support)(unconstrained)potential_energy = -dist.log_prob(sample).sum()

Note: An example where transform_to and biject_to differ is constraints.simplex: transform_to(constraints.simplex) returns a SoftmaxTransform that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, biject_to(constraints.simplex) returns a StickBreakingTransform that bijects its input down to a one-fewer-dimensional space; this a more expensive less numerically stable transform but is needed for algorithms like HMC.

The biject_to and transform_to objects can be extended by user-defined constraints and transforms using their .register() method either as a function on singleton constraints:

transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints:

@transform_to.register(MyConstraintClass)
def my_factory(constraint):assert isinstance(constraint, MyConstraintClass)return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new ConstraintRegistry
object.


class torch.distributions.constraint_registry.ConstraintRegistry

Registry to link constraints to transforms.


register(constraint, factory=None)

Registers a Constraint
subclass in this registry. Usage:

@my_registry.register(MyConstraintClass)
def construct_transform(constraint):assert isinstance(constraint, MyConstraint)return MyTransform(constraint.arg_constraints)

参数说明

  • constraint (Constraint的子类) - 可以是Constraint的子类,或是目标类的单例对象。
  • factory (可调用对象) - 一个可调用对象,接收约束对象作为输入并返回一个Transform对象。

2025-08-20(三)

http://www.dtcms.com/a/341132.html

相关文章:

  • 使数组k递增的最少操作次数
  • 路由器的NAT类型
  • 确保测试环境一致性与稳定性 5大策略
  • AI 效应: GPT-6,“用户真正想要的是记忆”
  • 获取本地IP地址、MAC地址写法
  • SQL 中大于小于号的表示方法总结
  • Bitcoin有升值潜力吗
  • 《代码沙盒深度实战:iframe安全隔离与实时双向通信的架构设计与落地策略》
  • 在SQL中使用大模型时间预测模型TimesFM
  • Mybatis执行SQL流程(五)之MapperProxy与MapperMethod
  • zoho crm api 无法修改富文本字段的原因:api 版本太低
  • 23种设计模式——构建器模式(Builder Pattern)详解
  • Spring Boot Controller 使用 @RequestBody + @ModelAttribute 接收请求
  • 车联网(V2X)中万物的重新定义---联网汽车新时代
  • Dubbo 的 Java 项目间调用的完整示例
  • 分析NeRF模型中颜色计算公式中的参数
  • Paraformer实时语音识别中的碎碎念
  • RuntimeError: Dataset scripts are no longer supported, but found wikipedia.py
  • 车辆订单状态管理的优化方案:状态机设计模式
  • 从ioutil到os:Golang在线客服聊天系统文件读取的迁移实践
  • 从零开发Java坦克大战Ⅱ(上) -- 从单机到联机(架构演进与设计模式剖析)
  • 音频大模型学习笔记
  • CS+ for CC编译超慢的问题该如何解决
  • 0-1 背包问题(模板)
  • 汽车ECU实现数据安全存储(机密性保护)的一种方案
  • Ubuntu apt安装nginx
  • 使用Spring Retry组件优雅地实现重试
  • Java 定时任务 - 从基础到高阶使用 - 从 Timer 到 Quart
  • 数据结构 二叉树 二叉树链式结构的实现
  • 数据分析师常用命令