pytorch中的FSDP
文章目录
- pytorch中的FSDP
- 目录
- 1. FSDP 是什么?一句话总结
- 2. 为什么需要 FSDP?(DDP 的局限性)
- 3. FSDP 的核心思想:"完全分片"是如何工作的?
- 4. FSDP 的主要优势
- 5. 如何在 PyTorch 中使用 FSDP?(代码示例)
- 6. FSDP vs. DDP vs. DP 对比
- 7. 高级技巧和最佳实践
- 总结
pytorch中的FSDP
可以把它理解为 PyTorch 官方推出的、用于训练超大规模模型的“终极武器”,是 DistributedDataParallel (DDP)
的进阶版和替代方案。
目录
- FSDP 是什么?一句话总结
- 为什么需要 FSDP?(DDP 的局限性)
- FSDP 的核心思想:"完全分片"是如何工作的?
- FSDP 的主要优势
- 如何在 PyTorch 中使用 FSDP?(代码示例)
- FSDP vs. DDP vs. DP 对比
- 高级技巧和最佳实践
1. FSDP 是什么?一句话总结
FSDP (Fully Sharded Data Parallel) 是一种高效的分布式训练技术,它通过将模型参数、梯度和优化器状态“分片”到所有 GPU 上,极大地降低了单个 GPU 的内存消耗,从而能够训练远超单个 GPU 内存容量的超大模型。
2. 为什么需要 FSDP?(DDP 的局限性)
要理解 FSDP,我们先要看看它的前辈们有什么问题。
-
DataParallel
(DP):非常简单,但效率低下。它会将模型复制到每个 GPU,但在主 GPU 上汇总梯度,导致主 GPU 负载和内存开销巨大。基本已被弃用。 -
DistributedDataParallel
(DDP):目前最主流的分布式训练方式。每个 GPU 进程都拥有一个完整的模型副本。在反向传播后,各个 GPU 上的梯度会通过All-Reduce
操作进行通信,确保所有模型的梯度同步,然后各自更新模型。
DDP 的核心局限性在于:
每个 GPU 必须独立存储整个模型的完整副本、完整的梯度和完整的优化器状态。
这导致一个严重的问题:当模型变得非常大时(例如,百亿、千亿参数的模型,如 GPT-3),没有任何一张 GPU 能存得下如此庞大的模型和其相关的优化器状态。
举个例子: 一个 100 亿参数的模型(FP32),光是参数就需要 10B * 4 bytes ≈ 40GB
内存。如果使用 Adam 优化器,优化器状态(动量和方差)还需要额外的 10B * 4 bytes * 2 = 80GB
。再加上梯度(40GB
)和中间激活值,单个 GPU 根本无法承受。
而 FSDP 正是为了解决这个内存瓶颈而诞生的。
3. FSDP 的核心思想:"完全分片"是如何工作的?
FSDP 的名字 “Fully Sharded”(完全分片)已经揭示了它的核心。它将以下三种主要的大块内存分片 (Shard) 到数据并行组中的所有 GPU 上:
- 模型参数 (Model Parameters)
- 梯度 (Gradients)
- 优化器状态 (Optimizer States)
每个 GPU 只负责自己“那一份”参数、梯度和优化器状态的存储和更新。
FSDP 的工作流程(简化版):
想象一个包含多个层的模型,被 FSDP 包装后:
-
初始化:在训练开始前,FSDP 将整个模型的参数、优化器状态分片,每个 GPU 只保留自己负责的一小部分。此时,没有任何一个 GPU 上有完整的模型。
-
前向传播 (Forward Pass):
- 当计算流到达某个 FSDP 包装的模块(比如一个 Transformer Block)时,FSDP 会触发一个
all-gather
通信操作,从所有其他 GPU 上临时拉取它们所持有的该模块的参数分片,从而在当前 GPU 上重构出完整的模块。 - 执行该模块的前向计算。
- 计算一结束,立刻丢弃刚刚拉取来的参数分片,释放内存。
- 当计算流到达某个 FSDP 包装的模块(比如一个 Transformer Block)时,FSDP 会触发一个
-
反向传播 (Backward Pass):
- 流程与前向传播类似。当需要为某个模块计算梯度时,再次通过
all-gather
临时重构出完整的模块参数。 - 计算该模块的梯度。
- 梯度计算完成后,FSDP 不会像 DDP 那样保留完整的梯度,而是立即使用
reduce-scatter
操作。这个操作会一边求和所有 GPU 上的梯度,一边将结果分片,最终每个 GPU 只得到并存储它自己负责的那一部分参数的梯度。
- 流程与前向传播类似。当需要为某个模块计算梯度时,再次通过
-
优化器更新 (Optimizer Step):
- 每个 GPU 使用它本地存储的参数分片、梯度分片和优化器状态分片,独立地更新它自己负责的那一小部分模型参数。
这个过程巧妙地实现了**“用到时才重构,用完即焚”**,确保了在任何时刻,单个 GPU 上的内存占用都非常低。
4. FSDP 的主要优势
- 巨大的内存节省:这是最核心的优势。它使得在商用 GPU(如 A100 80GB)上训练千亿级别参数的模型成为可能。
- 高效的通信:FSDP 将 DDP 在反向传播结束时进行的一次性、大规模的
All-Reduce
通信,分解为贯穿于前向和反向传播过程中的多个小规模all-gather
和reduce-scatter
操作。这使得计算和通信可以重叠 (Overlap),隐藏了通信延迟,提高了训练吞吐量。 - 与 DDP 媲美的速度:对于普通大小的模型,FSDP 的速度通常与 DDP 相当。而对于超大模型,FSDP 是唯一可行的选择。
5. 如何在 PyTorch 中使用 FSDP?(代码示例)
使用 FSDP 的代码改动相对较小,主要集中在模型和优化器的初始化部分。
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed.fsdp.api import ShardingStrategy
import torch.distributed as dist
import os# 1. 初始化分布式环境
# 通常使用 torchrun 或者 slurm 启动
# export MASTER_ADDR='localhost'
# export MASTER_PORT='12355'
# torchrun --nproc_per_node=2 your_script.py
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)# 定义你的模型
class MyModel(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(10, 2048)self.layer2 = nn.Linear(2048, 2048)self.layer3 = nn.Linear(2048, 5)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 将模型移动到 GPU
model = MyModel().to(local_rank)# 2. 定义 FSDP 包装策略
# size_based_auto_wrap_policy 是一个常用的自动包装策略
# 它会根据模块的大小自动决定哪些模块应该被 FSDP 包装
# functools.partial 用于固定策略的参数
import functools
auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000
)# 3. 使用 FSDP 包装模型
# ShardingStrategy.FULL_SHARD 是最彻底的分片策略,最省内存
fsdp_model = FSDP(model,auto_wrap_policy=auto_wrap_policy,sharding_strategy=ShardingStrategy.FULL_SHARD, # 常用策略device_id=torch.cuda.current_device()
)# 4. 创建优化器(!!!必须在模型被 FSDP 包装之后创建!!!)
# FSDP 会将参数展平并分片,所以优化器需要看到的是 FSDP 处理后的参数
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=0.001)# 训练循环(与普通训练几乎一样)
for epoch in range(num_epochs):for data, target in train_loader:data, target = data.to(local_rank), target.to(local_rank)optimizer.zero_grad()output = fsdp_model(data)loss = torch.nn.functional.cross_entropy(output, target)loss.backward()optimizer.step()# 保存模型状态(需要特殊处理)
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from torch.distributed.fsdp.api import CPUOffload# 获取完整的模型状态字典,需要集中到 rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_policy):cpu_state_dict = fsdp_model.state_dict()if dist.get_rank() == 0:torch.save(cpu_state_dict, "full_model_checkpoint.pt")dist.destroy_process_group()
关键点:
- 自动包装策略 (
auto_wrap_policy
): FSDP 不是把整个模型当成一个大块来包装,而是需要将模型内部的某些模块(如 Transformer 的 Block)独立包装。auto_wrap_policy
可以帮助我们自动完成这个过程,非常方便。 - 优化器创建时机: 必须在模型被
FSDP()
包装之后创建优化器,否则优化器无法正确处理分片后的参数。 - 保存/加载模型: 由于参数是分片的,保存和加载需要使用 FSDP 提供的特定上下文管理器
FSDP.state_dict_type
来正确地收集或分发完整的模型权重。
6. FSDP vs. DDP vs. DP 对比
特性 | DataParallel (DP) | DistributedDataParallel (DDP) | Fully Sharded Data Parallel (FSDP) |
---|---|---|---|
模型存储 | 每个 GPU 复制一份 | 每个 GPU 复制一份 | 分片存储在所有 GPU |
梯度存储 | 汇总到主 GPU | 每个 GPU 复制一份 | 分片存储在所有 GPU |
优化器状态 | 在主 GPU | 每个 GPU 复制一份 | 分片存储在所有 GPU |
单 GPU 内存 | 高(主 GPU 极高) | 高 | 低 |
通信方式 | Scatter -> Gather | All-Reduce | All-Gather -> Reduce-Scatter |
通信/计算重叠 | 否 | 是 | 是(更优) |
适用场景 | 教学、小模型(不推荐) | 主流,绝大多数模型 | 超大模型训练,内存受限场景 |
7. 高级技巧和最佳实践
- 混合精度 (
MixedPrecision
): FSDP 与torch.cuda.amp
或bfloat16
结合使用可以进一步节省内存和加速训练。 - CPU Offloading: 对于极端情况(模型分片后仍然放不进 GPU 内存),FSDP 支持将不活跃的参数分片卸载(offload)到 CPU 内存中。这会牺牲速度,但能让你训练更大的模型。
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
- 激活检查点 (
Activation Checkpointing
): FSDP 解决了参数/梯度/优化器的内存问题,而激活检查点解决了前向传播中中间激活值的内存问题。两者结合是训练巨型模型的标配。 - 分片策略 (
ShardingStrategy
):FULL_SHARD
: 参数、梯度、优化器状态全部分片,最省内存。SHARD_GRAD_OP
: 只分片梯度和优化器状态,模型参数在每个 GPU 上仍是完整的(类似 DDP)。内存消耗介于 DDP 和FULL_SHARD
之间,有时性能更好。HYBRID_SHARD
: 节点内进行FULL_SHARD
,节点间进行 DDP 式的复制。适用于拥有高速节点内连接(如 NVLink)的大型集群。
总结
FSDP 是 PyTorch 生态中用于大规模分布式训练的未来方向。它通过巧妙的“分片”思想,打破了单卡内存的限制,同时通过计算与通信的重叠保持了高效率。如果你需要训练的模型的规模已经超出了 DDP 的能力范围,那么 FSDP 就是你必须掌握的工具。