当前位置: 首页 > news >正文

Pytorch FSDP权重分片保存与合并

注:本文章方法只适用Pytorch FSDP1的模型,且切分策略为FULL_STATE_DICT场景。

在使用FSDP训练模型时,为了节省显存通常会把模型权重也进行切分,在保存权重时为了加速保存通常每个进程各自保存自己持有的部分权重,避免先汇聚到主进程再保存浪费大量时间的问题。保存成分片权重后,如果需要推理则还需要将分片权重进行合并。下面提供了保存分片权重以及将分片权重合并的代码示例,代码主要参考accelerate官方源码。

import osimport torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
import torch.distributed.checkpoint.format_utils as dist_cp_format_utilsdef save_fsdp_model(model: FSDP, fsdp_ckpt_path: str):# refer accelerate/utils/fsdp_utils.py:save_fsdp_modelwith FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):os.makedirs(fsdp_ckpt_path, exist_ok=True)state_dict = {"model": model.state_dict()}dist_cp.save(state_dict=state_dict,storage_writer=dist_cp.FileSystemWriter(fsdp_ckpt_path),planner=DefaultSavePlanner(),)def merge_fsdp_weights(fsdp_ckpt_path: str, save_path: str):# refer accelerate/utils/fsdp_utils.py:merge_fsdp_weightsstate_dict = {}dist_cp_format_utils._load_state_dict(state_dict,storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_path),planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),no_dist=True,)# To handle if state is a dict like {model: {...}}if len(state_dict.keys()) == 1:state_dict = state_dict[list(state_dict)[0]]torch.save(state_dict, save_path)
http://www.dtcms.com/a/329464.html

相关文章:

  • Node.js简介及安装
  • 人工到智能:塑料袋拆垛的自动化革命 —— 迁移科技的实践与创新
  • Node.js浏览器引擎+Python大脑的智能爬虫系统
  • Vue3从入门到精通: 3.5 Vue3与TypeScript集成深度解析
  • 热门手机机型重启速度对比
  • PCB题目基础练习2
  • 从“字”到“画”:基于Elasticsearch Serverless 的多模态商品搜索实践
  • aave v3 存款利息的计算方式
  • 《红黑树的原理与C++实现:详解平衡艺术的高效构建与操作》
  • 无人设备遥控器之编码技术篇
  • 【剑指offer】搜索算法
  • 力扣(跳跃游戏I/II)
  • c++26新功能—inplace_vector
  • 达梦数据库常见漏洞及处理方案
  • PostgreSQL——索引
  • TensorFlow实现回归分析详解
  • npm install 的作用
  • HTTP 请求转发与重定向详解及其应用(含 Java 示例)
  • Windows平台RTSP播放器选型与低延迟全解析及技术实践
  • 迅为RK3568开发板模型推理测试实战deeplabv3语义分割
  • Java基础 8.13
  • 【Flowable】工作流网关 控制流程的流向
  • 深度学习——03 神经网络(3)-网络优化方法
  • 门店销售机器人的智能升级:具身智能模型带来的变革
  • Mac安装ant
  • Linux性能分析教程:top, htop, iotop命令使用详解 (服务器慢/卡顿排查)
  • 电脑如何安装win10专业版_电脑用u盘安装win10专业版教程
  • GO学习记录四——读取excel完成数据库建表
  • 10.反射获取静态类的属性 C#例子 WPF例子
  • 5.0.9.1 C# wpf通过WindowsFormsHost嵌入windows media player(AxInterop.WMPLib)