当前位置: 首页 > news >正文

pytorch中的FSDP

文章目录

  • pytorch中的FSDP
      • 目录
      • 1. FSDP 是什么?一句话总结
      • 2. 为什么需要 FSDP?(DDP 的局限性)
      • 3. FSDP 的核心思想:"完全分片"是如何工作的?
      • 4. FSDP 的主要优势
      • 5. 如何在 PyTorch 中使用 FSDP?(代码示例)
      • 6. FSDP vs. DDP vs. DP 对比
      • 7. 高级技巧和最佳实践
      • 总结

pytorch中的FSDP

可以把它理解为 PyTorch 官方推出的、用于训练超大规模模型的“终极武器”,是 DistributedDataParallel (DDP) 的进阶版和替代方案。


目录

  1. FSDP 是什么?一句话总结
  2. 为什么需要 FSDP?(DDP 的局限性)
  3. FSDP 的核心思想:"完全分片"是如何工作的?
  4. FSDP 的主要优势
  5. 如何在 PyTorch 中使用 FSDP?(代码示例)
  6. FSDP vs. DDP vs. DP 对比
  7. 高级技巧和最佳实践

1. FSDP 是什么?一句话总结

FSDP (Fully Sharded Data Parallel) 是一种高效的分布式训练技术,它通过将模型参数、梯度和优化器状态“分片”到所有 GPU 上,极大地降低了单个 GPU 的内存消耗,从而能够训练远超单个 GPU 内存容量的超大模型。


2. 为什么需要 FSDP?(DDP 的局限性)

要理解 FSDP,我们先要看看它的前辈们有什么问题。

  • DataParallel (DP):非常简单,但效率低下。它会将模型复制到每个 GPU,但在主 GPU 上汇总梯度,导致主 GPU 负载和内存开销巨大。基本已被弃用

  • DistributedDataParallel (DDP):目前最主流的分布式训练方式。每个 GPU 进程都拥有一个完整的模型副本。在反向传播后,各个 GPU 上的梯度会通过 All-Reduce 操作进行通信,确保所有模型的梯度同步,然后各自更新模型。

DDP 的核心局限性在于:
每个 GPU 必须独立存储整个模型的完整副本、完整的梯度和完整的优化器状态。

这导致一个严重的问题:当模型变得非常大时(例如,百亿、千亿参数的模型,如 GPT-3),没有任何一张 GPU 能存得下如此庞大的模型和其相关的优化器状态。

举个例子: 一个 100 亿参数的模型(FP32),光是参数就需要 10B * 4 bytes ≈ 40GB 内存。如果使用 Adam 优化器,优化器状态(动量和方差)还需要额外的 10B * 4 bytes * 2 = 80GB。再加上梯度(40GB)和中间激活值,单个 GPU 根本无法承受。

而 FSDP 正是为了解决这个内存瓶颈而诞生的。


3. FSDP 的核心思想:"完全分片"是如何工作的?

FSDP 的名字 “Fully Sharded”(完全分片)已经揭示了它的核心。它将以下三种主要的大块内存分片 (Shard) 到数据并行组中的所有 GPU 上:

  1. 模型参数 (Model Parameters)
  2. 梯度 (Gradients)
  3. 优化器状态 (Optimizer States)

每个 GPU 只负责自己“那一份”参数、梯度和优化器状态的存储和更新。

FSDP 的工作流程(简化版):

想象一个包含多个层的模型,被 FSDP 包装后:

  1. 初始化:在训练开始前,FSDP 将整个模型的参数、优化器状态分片,每个 GPU 只保留自己负责的一小部分。此时,没有任何一个 GPU 上有完整的模型。

  2. 前向传播 (Forward Pass)

    • 当计算流到达某个 FSDP 包装的模块(比如一个 Transformer Block)时,FSDP 会触发一个 all-gather 通信操作,从所有其他 GPU 上临时拉取它们所持有的该模块的参数分片,从而在当前 GPU 上重构出完整的模块
    • 执行该模块的前向计算。
    • 计算一结束,立刻丢弃刚刚拉取来的参数分片,释放内存。
  3. 反向传播 (Backward Pass)

    • 流程与前向传播类似。当需要为某个模块计算梯度时,再次通过 all-gather 临时重构出完整的模块参数。
    • 计算该模块的梯度。
    • 梯度计算完成后,FSDP 不会像 DDP 那样保留完整的梯度,而是立即使用 reduce-scatter 操作。这个操作会一边求和所有 GPU 上的梯度,一边将结果分片,最终每个 GPU 只得到并存储它自己负责的那一部分参数的梯度
  4. 优化器更新 (Optimizer Step)

    • 每个 GPU 使用它本地存储的参数分片梯度分片优化器状态分片,独立地更新它自己负责的那一小部分模型参数。

这个过程巧妙地实现了**“用到时才重构,用完即焚”**,确保了在任何时刻,单个 GPU 上的内存占用都非常低。


4. FSDP 的主要优势

  1. 巨大的内存节省:这是最核心的优势。它使得在商用 GPU(如 A100 80GB)上训练千亿级别参数的模型成为可能。
  2. 高效的通信:FSDP 将 DDP 在反向传播结束时进行的一次性、大规模的 All-Reduce 通信,分解为贯穿于前向和反向传播过程中的多个小规模 all-gatherreduce-scatter 操作。这使得计算和通信可以重叠 (Overlap),隐藏了通信延迟,提高了训练吞吐量。
  3. 与 DDP 媲美的速度:对于普通大小的模型,FSDP 的速度通常与 DDP 相当。而对于超大模型,FSDP 是唯一可行的选择。

5. 如何在 PyTorch 中使用 FSDP?(代码示例)

使用 FSDP 的代码改动相对较小,主要集中在模型和优化器的初始化部分。

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed.fsdp.api import ShardingStrategy
import torch.distributed as dist
import os# 1. 初始化分布式环境
# 通常使用 torchrun 或者 slurm 启动
# export MASTER_ADDR='localhost'
# export MASTER_PORT='12355'
# torchrun --nproc_per_node=2 your_script.py
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)# 定义你的模型
class MyModel(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(10, 2048)self.layer2 = nn.Linear(2048, 2048)self.layer3 = nn.Linear(2048, 5)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 将模型移动到 GPU
model = MyModel().to(local_rank)# 2. 定义 FSDP 包装策略
# size_based_auto_wrap_policy 是一个常用的自动包装策略
# 它会根据模块的大小自动决定哪些模块应该被 FSDP 包装
# functools.partial 用于固定策略的参数
import functools
auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000
)# 3. 使用 FSDP 包装模型
# ShardingStrategy.FULL_SHARD 是最彻底的分片策略,最省内存
fsdp_model = FSDP(model,auto_wrap_policy=auto_wrap_policy,sharding_strategy=ShardingStrategy.FULL_SHARD, # 常用策略device_id=torch.cuda.current_device()
)# 4. 创建优化器(!!!必须在模型被 FSDP 包装之后创建!!!)
# FSDP 会将参数展平并分片,所以优化器需要看到的是 FSDP 处理后的参数
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=0.001)# 训练循环(与普通训练几乎一样)
for epoch in range(num_epochs):for data, target in train_loader:data, target = data.to(local_rank), target.to(local_rank)optimizer.zero_grad()output = fsdp_model(data)loss = torch.nn.functional.cross_entropy(output, target)loss.backward()optimizer.step()# 保存模型状态(需要特殊处理)
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from torch.distributed.fsdp.api import CPUOffload# 获取完整的模型状态字典,需要集中到 rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_policy):cpu_state_dict = fsdp_model.state_dict()if dist.get_rank() == 0:torch.save(cpu_state_dict, "full_model_checkpoint.pt")dist.destroy_process_group()

关键点:

  • 自动包装策略 (auto_wrap_policy): FSDP 不是把整个模型当成一个大块来包装,而是需要将模型内部的某些模块(如 Transformer 的 Block)独立包装。auto_wrap_policy 可以帮助我们自动完成这个过程,非常方便。
  • 优化器创建时机: 必须在模型被 FSDP() 包装之后创建优化器,否则优化器无法正确处理分片后的参数。
  • 保存/加载模型: 由于参数是分片的,保存和加载需要使用 FSDP 提供的特定上下文管理器 FSDP.state_dict_type 来正确地收集或分发完整的模型权重。

6. FSDP vs. DDP vs. DP 对比

特性DataParallel (DP)DistributedDataParallel (DDP)Fully Sharded Data Parallel (FSDP)
模型存储每个 GPU 复制一份每个 GPU 复制一份分片存储在所有 GPU
梯度存储汇总到主 GPU每个 GPU 复制一份分片存储在所有 GPU
优化器状态在主 GPU每个 GPU 复制一份分片存储在所有 GPU
单 GPU 内存高(主 GPU 极高)
通信方式Scatter -> GatherAll-ReduceAll-Gather -> Reduce-Scatter
通信/计算重叠是(更优)
适用场景教学、小模型(不推荐)主流,绝大多数模型超大模型训练,内存受限场景

7. 高级技巧和最佳实践

  • 混合精度 (MixedPrecision): FSDP 与 torch.cuda.ampbfloat16 结合使用可以进一步节省内存和加速训练。
  • CPU Offloading: 对于极端情况(模型分片后仍然放不进 GPU 内存),FSDP 支持将不活跃的参数分片卸载(offload)到 CPU 内存中。这会牺牲速度,但能让你训练更大的模型。
    fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
    
  • 激活检查点 (Activation Checkpointing): FSDP 解决了参数/梯度/优化器的内存问题,而激活检查点解决了前向传播中中间激活值的内存问题。两者结合是训练巨型模型的标配。
  • 分片策略 (ShardingStrategy):
    • FULL_SHARD: 参数、梯度、优化器状态全部分片,最省内存。
    • SHARD_GRAD_OP: 只分片梯度和优化器状态,模型参数在每个 GPU 上仍是完整的(类似 DDP)。内存消耗介于 DDP 和 FULL_SHARD 之间,有时性能更好。
    • HYBRID_SHARD: 节点内进行 FULL_SHARD,节点间进行 DDP 式的复制。适用于拥有高速节点内连接(如 NVLink)的大型集群。

总结

FSDP 是 PyTorch 生态中用于大规模分布式训练的未来方向。它通过巧妙的“分片”思想,打破了单卡内存的限制,同时通过计算与通信的重叠保持了高效率。如果你需要训练的模型的规模已经超出了 DDP 的能力范围,那么 FSDP 就是你必须掌握的工具。


文章转载自:

http://P07ApcJr.fqmbt.cn
http://OOvACm3p.fqmbt.cn
http://eLQ9KU3U.fqmbt.cn
http://xVPtAjwm.fqmbt.cn
http://4fR7AXA4.fqmbt.cn
http://ZIFAVwl4.fqmbt.cn
http://37UX82jZ.fqmbt.cn
http://eMAlMV8H.fqmbt.cn
http://zD3YTawg.fqmbt.cn
http://XuXKc9nh.fqmbt.cn
http://yfumG18E.fqmbt.cn
http://IK9juWa0.fqmbt.cn
http://kdM3RlSf.fqmbt.cn
http://ekcJOfcb.fqmbt.cn
http://Rr7Mvzkg.fqmbt.cn
http://eYmWrDHF.fqmbt.cn
http://LgqOJoPo.fqmbt.cn
http://92ODOugO.fqmbt.cn
http://Z2OGwpK0.fqmbt.cn
http://6MYhUAqb.fqmbt.cn
http://z8VzwszU.fqmbt.cn
http://eLDxxj8Y.fqmbt.cn
http://Tj6xBhNN.fqmbt.cn
http://9ZwRM09P.fqmbt.cn
http://BB6yypN6.fqmbt.cn
http://u2UlGFqd.fqmbt.cn
http://uC3EdOaU.fqmbt.cn
http://o2OtKrq6.fqmbt.cn
http://Oi8qKNSg.fqmbt.cn
http://K0aBA1rD.fqmbt.cn
http://www.dtcms.com/a/388628.html

相关文章:

  • 贪心算法与材料切割问题详解
  • 2. 结构体
  • MySQL 核心操作:多表联合查询与数据库备份恢复
  • vue3学习日记(十四):两大API选型指南
  • 微信支付回调成功通知到本地
  • 量化交易 - Simple Regression 简单线性回归(机器学习)
  • Kubernetes控制器详解:从Deployment到CronJob
  • python 架构技术50
  • 第九周文件上传
  • MCP大白话理解
  • 【Qt】QJsonValue存储 int64 类型的大整数时,数值出现莫名其妙的变化
  • 【C语言】冒泡排序算法解析与实现
  • [GESP202309 三级] 进制判断
  • 【C++】const和static的用法
  • 箭头函数{}规则,以及隐式返回
  • brain.js构建训练神经网络
  • 开学季高效学习与知识管理技术
  • C++STL与字符串探秘
  • 【面试题】- 使用CompletableFuture实现多线程统计策略工厂模式
  • 打工人日报#20250917
  • LeetCode:12.最小覆盖字串
  • 【C++】 深入理解C++虚函数表与对象析构机制
  • C++ 中 ->和 . 操作符的区别
  • SQL CTE (Common Table Expression) 详解
  • 解决windows更新之后亮度条消失无法调节的问题
  • FPGA学习篇——Verilog学习译码器的实现
  • JavaScript Promise 终极指南 解决回调地狱的异步神器 99% 开发者都在用
  • AI智能体开发实战:从提示工程转向上下文工程的完整指南
  • jtag协议处理流程
  • 【LeetCode 每日一题】2749. 得到整数零需要执行的最少操作数