当前位置: 首页 > news >正文

Pytorch在FSDP模型中使用EMA

注:本文章方法只在Pytorch FSDP1的模型上实验过,且切分策略为SHARDED_STATE_DICT场景。

使用FSDP对模型权重切分后如何使用EMA网上搜了一圈没找到个一个靠谱的办法,干脆自己写一个算了,实现代码如下:

import os
from typing import Dict, List
from collections import defaultdictimport torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlannerclass ShardEMAModel:def __init__(self, fsdp_model: FSDP, decay: float = 0.999):assert isinstance(fsdp_model, FSDP)self.fsdp_model = fsdp_modelself.decay = decayself.shard_ema_state: Dict[str, List[torch.Tensor]] = defaultdict(list)shard_state = self._get_shard_state()for k, v in shard_state.items():for local_shard in v._local_shards:self.shard_ema_state[k].append(local_shard.tensor.clone())self.num_shard_params = sum([sum([t.numel() for t in v]) for v in self.shard_ema_state.values()])print(f"Shard EMA Model has {self.num_shard_params / 1e6:.3f}M params.")def _get_shard_state(self):with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):shard_state = self.fsdp_model.state_dict()return shard_state@torch.inference_mode()def update(self):"""update EMA Model shard weights"""shard_state = self._get_shard_state()for k, v in shard_state.items():for idx, local_shard in enumerate(v._local_shards):self.shard_ema_state[k][idx].mul_(self.decay).add_(local_shard.tensor, alpha=1 - self.decay)def save_ema_shard_weights(self, save_dir: str):"""save EMA Model shard weights"""with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):os.makedirs(save_dir, exist_ok=True)shard_state = self.fsdp_model.state_dict()for k, v in shard_state.items():for idx, local_shard in enumerate(v._local_shards):local_shard.tensor = self.shard_ema_state[k][idx]state_dict = {"model": shard_state}dist_cp.save(state_dict=state_dict,storage_writer=dist_cp.FileSystemWriter(save_dir),planner=DefaultSavePlanner(),)def save_shard_weights(self, save_dir: str):"""save original FSDP Model shard weights"""with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):os.makedirs(save_dir, exist_ok=True)shard_state = self.fsdp_model.state_dict()state_dict = {"model": shard_state}dist_cp.save(state_dict=state_dict,storage_writer=dist_cp.FileSystemWriter(save_dir),planner=DefaultSavePlanner(),)

使用示例:

# create FSDP Model and EMA Model
fsdp_model = FSDP(...)
ema_model = ShardEMAModel(fsdp_model, decay=0.99)# train fsdp model and optimizer weights
...# update EMA Model shard weights
ema_model.update()# save EMA Model shard weights
ema_model.save_ema_shard_weights("save_path")
http://www.dtcms.com/a/331500.html

相关文章:

  • Leetcode_1780.判断一个数字是否可以表示成三的幂的和
  • UE5 C++ 删除文件
  • BotCash:GPT-5发布观察 工程优化的进步,还是技术突破的瓶颈?
  • Spring Boot + Redis Cluster 测试
  • 回流(Reflow)与重绘(Repaint):浏览器渲染性能优化核心
  • 演员念真主演《镇恶追凶》辽宁杀青
  • 数字电路上的通讯速度是越快越好还是越慢越好?
  • 【二分图】染色问题
  • 企业智脑UMI AIGC SaaS:解锁AI时代全场景生产力,中小微企业转型利器
  • Linux学习-多任务(进程)
  • **隐私沙盒:发散创新之光**随着互联网技术的飞速发展,数据安全和隐私保护逐渐成为人们关注的焦点。隐私沙盒作为一种新兴
  • Ping32 与绿盾再对比:Ping32 以创新与适配领跑数据安全​
  • 机器学习内容总结
  • 机器学习-基础入门:从概念到核心方法论
  • MySQL进阶——优化、日志
  • 第4节课:多模态大模型的核心能力(多模态大模型基础教程)
  • 疏老师-python训练营-Day45Tensorboard使用介绍
  • StarRocks优化统计分析
  • 好用的开源数据可视化设计工具LIGHT CHASER
  • Java List 集合详解(ArrayList、LinkedList、Vector)
  • pyecharts可视化图表-pie:从入门到精通
  • 适用工业分选和工业应用的高光谱相机有哪些?什么品牌比较好?
  • 这个就是哈希冲突
  • AI出题人给出的Java后端面经(十四)(日更)
  • 智慧养老解决方案:破解“最后一公里”服务难题
  • 【98页PPT】智慧方案某著名企业汽配行业ERP整体解决方案(附下载方式)
  • BGP笔记及实验
  • 网络层协议——IP
  • 2025年机器视觉与信号处理国际会议(MVSP 2025)
  • 72小时到24小时:台风“杨柳”过后,有鹿机器人如何为园区按下“加速键”?