Using EMA with PyTorch FSDP Models
Note: the approach described here has only been tested on PyTorch FSDP1 models, and only with the SHARDED_STATE_DICT state-dict setting.
Once a model's weights are sharded with FSDP, there is no obvious way to maintain an EMA (exponential moving average) copy of them. I searched around and could not find a reliable solution, so I ended up writing one myself. The implementation is as follows:
import os
from typing import Dict, List
from collections import defaultdict

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner


class ShardEMAModel:
    def __init__(self, fsdp_model: FSDP, decay: float = 0.999):
        assert isinstance(fsdp_model, FSDP)
        self.fsdp_model = fsdp_model
        self.decay = decay
        # Map parameter name -> list of this rank's local EMA shards.
        self.shard_ema_state: Dict[str, List[torch.Tensor]] = defaultdict(list)
        # Initialize the EMA shards as clones of the current local shards.
        shard_state = self._get_shard_state()
        for k, v in shard_state.items():
            for local_shard in v._local_shards:
                self.shard_ema_state[k].append(local_shard.tensor.clone())
        self.num_shard_params = sum(
            sum(t.numel() for t in v) for v in self.shard_ema_state.values()
        )
        print(f"Shard EMA Model has {self.num_shard_params / 1e6:.3f}M params.")

    def _get_shard_state(self):
        # SHARDED_STATE_DICT returns ShardedTensor values that hold only
        # the shards local to this rank.
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            shard_state = self.fsdp_model.state_dict()
        return shard_state

    @torch.inference_mode()
    def update(self):
        """Update the EMA shard weights from the current model weights."""
        shard_state = self._get_shard_state()
        for k, v in shard_state.items():
            for idx, local_shard in enumerate(v._local_shards):
                # ema = decay * ema + (1 - decay) * param, per local shard.
                self.shard_ema_state[k][idx].mul_(self.decay).add_(
                    local_shard.tensor, alpha=1 - self.decay
                )

    def save_ema_shard_weights(self, save_dir: str):
        """Save the EMA shard weights via distributed checkpointing."""
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            os.makedirs(save_dir, exist_ok=True)
            shard_state = self.fsdp_model.state_dict()
            # Swap each local shard's tensor for its EMA counterpart before saving.
            for k, v in shard_state.items():
                for idx, local_shard in enumerate(v._local_shards):
                    local_shard.tensor = self.shard_ema_state[k][idx]
            state_dict = {"model": shard_state}
            dist_cp.save(
                state_dict=state_dict,
                storage_writer=dist_cp.FileSystemWriter(save_dir),
                planner=DefaultSavePlanner(),
            )

    def save_shard_weights(self, save_dir: str):
        """Save the original FSDP model shard weights."""
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            os.makedirs(save_dir, exist_ok=True)
            shard_state = self.fsdp_model.state_dict()
            state_dict = {"model": shard_state}
            dist_cp.save(
                state_dict=state_dict,
                storage_writer=dist_cp.FileSystemWriter(save_dir),
                planner=DefaultSavePlanner(),
            )
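A few notes on how this works: with SHARDED_STATE_DICT, state_dict() returns ShardedTensor values, and the code above reaches into their _local_shards to read and write each rank's local shard directly, so the update ema = decay * ema + (1 - decay) * param runs purely on local tensors with no extra communication. The class does not cover loading the saved checkpoint back. Below is a minimal sketch of how that could look, not a tested part of this post; it assumes the same FSDP1 + SHARDED_STATE_DICT setup and a PyTorch version where torch.distributed.checkpoint exposes load() (the helper name and load_dir are hypothetical):

def load_ema_shard_weights(fsdp_model: FSDP, load_dir: str):
    """Hypothetical helper: restore shards saved by save_ema_shard_weights()."""
    with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
        # Use the model's own sharded state dict as the target layout;
        # dist_cp.load fills it in place with the shards read from disk.
        state_dict = {"model": fsdp_model.state_dict()}
        dist_cp.load(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(load_dir),
        )
        fsdp_model.load_state_dict(state_dict["model"])

Note that this loads the EMA weights into the FSDP model itself, e.g. for evaluation; if you want to keep training afterwards, save the original weights first with save_shard_weights().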
Usage example:
# create FSDP Model and EMA Model
fsdp_model = FSDP(...)
ema_model = ShardEMAModel(fsdp_model, decay=0.99)

# train fsdp model and optimizer weights
...

# update EMA Model shard weights
ema_model.update()

# save EMA Model shard weights
ema_model.save_ema_shard_weights("save_path")
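For inference it is often handier to have the EMA weights as a single unsharded file. A minimal sketch of an offline conversion, assuming a PyTorch version (>= 2.2) where torch.distributed.checkpoint.format_utils provides dcp_to_torch_save; both paths are placeholders:

# Single-process conversion of the sharded DCP checkpoint written by
# save_ema_shard_weights() into a regular torch.save file.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("save_path", "ema_consolidated.pt")
# torch.load("ema_consolidated.pt")["model"] then holds the full EMA state dict.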