[LLM Training] A few tricks in load_state: the factory design pattern
Answer
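You asked about the role of the WorkerState class, and in particular how its load_from_json method is used 1.
The role of WorkerState
WorkerState is a dataclass that saves and restores training state. It is used mainly for checkpointing and for resuming training 2. It has three core fields:
- step: the current training step
- log_history: a list holding the history of metrics recorded during training
- kv: a key-value dict for any additional state
How load_from_json is implemented
load_from_json is a class method that restores a WorkerState instance from a JSON file 3: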
@classmethod
def load_from_json(cls, load_dir: str, tag):
    """Create an instance from the content of `json_path`."""
    json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
    with open(json_path, "r", encoding="utf-8") as f:
        text = f.read()
    return cls(**json.loads(text))
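The method does three things:
- It builds the JSON file path from load_dir and tag, using the format worker_state_{tag}.json 4.
- It reads the JSON file's contents.
- It deserializes the JSON into a dict and passes that dict to the class constructor with ** unpacking.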
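The following is a minimal, self-contained sketch of the save/load round trip described above. MiniWorkerState is a hypothetical, simplified stand-in and not ROLL's class. It reuses the cited WORKER_STATE_NAME template. Unlike the real dataclass, it uses default_factory instead of None defaults plus __post_init__, and it calls json.dump instead of building the string first.

```python
import dataclasses
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Dict, List, Union

# File-name template as cited from roll/utils/worker_state.py
WORKER_STATE_NAME = "worker_state_{tag}.json"


@dataclass
class MiniWorkerState:
    """Hypothetical simplified stand-in for ROLL's WorkerState (not the real class)."""
    step: int = -1
    log_history: List[Dict[str, float]] = field(default_factory=list)
    kv: Dict[str, Union[float, Dict]] = field(default_factory=dict)

    def save_to_json(self, save_dir: str, tag: str) -> str:
        json_path = os.path.join(save_dir, WORKER_STATE_NAME.format(tag=tag))
        os.makedirs(save_dir, exist_ok=True)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(self), f, indent=2, sort_keys=True)
        return json_path

    @classmethod
    def load_from_json(cls, load_dir: str, tag: str) -> "MiniWorkerState":
        json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
        with open(json_path, "r", encoding="utf-8") as f:
            # JSON keys match the dataclass field names, so ** unpacking works
            return cls(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    state = MiniWorkerState(step=10, log_history=[{"system/step": 10, "loss": 0.5}], kv={"lr": 1e-5})
    path = state.save_to_json(tmp, tag="pipeline")
    restored = MiniWorkerState.load_from_json(tmp, tag="pipeline")
    print(os.path.basename(path))  # worker_state_pipeline.json
    print(restored == state)       # True
```

Dataclass equality compares field by field, so `restored == state` confirms that every field survived the JSON round trip.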
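Usage in the pipeline
In BasePipeline, WorkerState is used to save and restore the training state of the whole pipeline 5.
Restoring state at initialization
When training resumes from a checkpoint, the pipeline loads the previously saved state 6: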
if self.pipeline_config.resume_from_checkpoint:
    self.resume_from_checkpoint = download_model(self.pipeline_config.resume_from_checkpoint)
    logger.info(f"resume_from_checkpoint: {self.resume_from_checkpoint}")
    load_dir = os.path.join(self.resume_from_checkpoint, "pipeline")
    self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")
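Here tag="pipeline" marks this file as the pipeline-level state file 7.
Restoring training metrics
After loading the state, the pipeline replays the saved metric history into the tracker 8: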
def resume_metrics():
    for metrics in self.state.log_history:
        self.tracker.log(values=metrics, step=metrics["system/step"])

self.resume_futures.append(self.executor.submit(resume_metrics))
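The following sketch shows the metric replay. It assumes a hypothetical FakeTracker with the same log(values=..., step=...) call shape as the pipeline's tracker, and it uses a standard ThreadPoolExecutor as the executor. Waiting on the futures at the end is an assumption of this sketch; the cited code only shows them being appended.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List


class FakeTracker:
    """Hypothetical in-memory tracker standing in for the pipeline's real tracker."""

    def __init__(self) -> None:
        self.records: List[tuple] = []

    def log(self, values: Dict[str, float], step: int) -> None:
        self.records.append((step, dict(values)))


log_history = [
    {"system/step": 0, "loss": 1.2},
    {"system/step": 1, "loss": 0.9},
]
tracker = FakeTracker()
executor = ThreadPoolExecutor(max_workers=1)
resume_futures = []


def resume_metrics() -> None:
    # Re-log every saved metrics dict under the step it was originally recorded at
    for metrics in log_history:
        tracker.log(values=metrics, step=metrics["system/step"])


# Submitting to an executor lets the replay run off the calling thread
resume_futures.append(executor.submit(resume_metrics))

# Assumption: wait for the replay to finish before continuing
for fut in resume_futures:
    fut.result()
executor.shutdown()
print([step for step, _ in tracker.records])  # [0, 1]
```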
Saving state
At checkpoint time, the pipeline saves its current state 9:
save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id, "pipeline")
self.state.save_to_json(save_dir=save_dir, tag="pipeline")
self.state.save_rng_state(save_dir=save_dir, tag="pipeline")
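Saving RNG state alongside
WorkerState also provides static methods that save and load random number generator (RNG) state, so that a resumed run stays reproducible 10:
- save_rng_state: saves the Python, NumPy, and PyTorch CPU random states, plus the accelerator's RNG states (via current_platform)
- load_rng_state: restores those random states
At checkpoint time, these methods are used together with save_to_json/load_from_json 11.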
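The following is a reduced sketch of the RNG checkpoint idea. It uses only Python's standard random module and pickle, so it runs without NumPy or PyTorch. The .pkl file name is hypothetical (ROLL writes rng_state_{tag}.pth with torch.save), and the NumPy, CPU, and device states are omitted. The sketch checks the property that matters: after a restore, the generator reproduces the draws that followed the checkpoint.

```python
import os
import pickle
import random
import tempfile


def save_rng_state(save_dir: str, tag: str) -> None:
    # Only Python's `random` state here; ROLL also stores NumPy, CPU and device states
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"rng_state_{tag}.pkl"), "wb") as f:
        pickle.dump({"python": random.getstate()}, f)


def load_rng_state(load_dir: str, tag: str) -> bool:
    rng_file = os.path.join(load_dir, f"rng_state_{tag}.pkl")
    if not os.path.isfile(rng_file):
        return False  # nothing to restore, so reproducibility is not guaranteed
    with open(rng_file, "rb") as f:
        random.setstate(pickle.load(f)["python"])
    return True


with tempfile.TemporaryDirectory() as tmp:
    random.seed(1234)
    save_rng_state(tmp, tag="pipeline")
    expected = [random.random() for _ in range(3)]  # draws right after the checkpoint
    random.random()  # advance the generator further, as continued training would
    restored = load_rng_state(tmp, tag="pipeline")
    replayed = [random.random() for _ in range(3)]
    print(restored, replayed == expected)  # True True
```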
Similar usage in the Trainer
McaTrainer has similar state-restoring logic 12:
if resume_from_checkpoint is not None and os.path.isfile(
    os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
    self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
This code uses TrainerState instead of WorkerState, but the pattern is the same.
Notes
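WorkerState follows a simple serialize/deserialize design and stores training state as JSON 13. RNG state goes to a separate rng_state_{tag}.pth file written with torch.save. The tag parameter lets state files for several different roles (e.g. "pipeline", "actor_train") live in the same directory. With this design, ROLL can resume exactly where it left off after an interruption, including the training step, metric history, and RNG state 14.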
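The following small sketch shows how the tag keeps several roles' state files apart in one directory. save_state and load_state are hypothetical helpers that work on plain dicts. Only the worker_state_{tag}.json template comes from the cited source.

```python
import json
import os
import tempfile

WORKER_STATE_NAME = "worker_state_{tag}.json"  # template cited from worker_state.py


def save_state(save_dir: str, tag: str, state: dict) -> None:
    # Hypothetical helper: one JSON file per tag
    with open(os.path.join(save_dir, WORKER_STATE_NAME.format(tag=tag)), "w", encoding="utf-8") as f:
        json.dump(state, f)


def load_state(load_dir: str, tag: str) -> dict:
    with open(os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag)), "r", encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    # Two roles share one directory; the tag keeps their files apart
    save_state(tmp, "pipeline", {"step": 100, "log_history": [], "kv": {}})
    save_state(tmp, "actor_train", {"step": 42, "log_history": [], "kv": {"lr": 1e-6}})
    files = sorted(os.listdir(tmp))
    pipeline_step = load_state(tmp, "pipeline")["step"]
    actor_step = load_state(tmp, "actor_train")["step"]

print(files)  # ['worker_state_actor_train.json', 'worker_state_pipeline.json']
print(pipeline_step, actor_step)  # 100 42
```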
Citations
File: roll/utils/worker_state.py (L17-17)
WORKER_STATE_NAME = "worker_state_{tag}.json"
File: roll/utils/worker_state.py (L20-30)
@dataclass
class WorkerState:
    step: int = -1
    log_history: List[Dict[str, float]] = None
    kv: Dict[str, Union[float, Dict]] = None

    def __post_init__(self):
        if self.log_history is None:
            self.log_history = []
        if self.kv is None:
            self.kv = {}
File: roll/utils/worker_state.py (L32-38)
def save_to_json(self, save_dir: str, tag):
    """Save the content of this instance in JSON format inside `json_path`."""
    json_path = os.path.join(save_dir, WORKER_STATE_NAME.format(tag=tag))
    json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    with open(json_path, "w", encoding="utf-8") as f:
        f.write(json_string)
File: roll/utils/worker_state.py (L40-46)
@classmethod
def load_from_json(cls, load_dir: str, tag):
    """Create an instance from the content of `json_path`."""
    json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
    with open(json_path, "r", encoding="utf-8") as f:
        text = f.read()
    return cls(**json.loads(text))
File: roll/utils/worker_state.py (L48-77)
@staticmethod
def save_rng_state(save_dir, tag):
    # Save RNG state in non-distributed training
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
        current_platform.device_type: current_platform.random.get_rng_state_all(),
    }
    os.makedirs(save_dir, exist_ok=True)
    torch.save(rng_states, os.path.join(save_dir, f"rng_state_{tag}.pth"))

@staticmethod
def load_rng_state(load_dir, tag):
    # Load RNG states from `checkpoint`
    if load_dir is None:
        return
    rng_file = os.path.join(load_dir, f"rng_state_{tag}.pth")
    if not os.path.isfile(rng_file):
        logger.info(
            f"Didn't find an RNG file for process {tag}, if you are resuming a training that "
            "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
        )
        return
    checkpoint_rng_state = torch.load(rng_file, weights_only=False)
    random.setstate(checkpoint_rng_state["python"])
    np.random.set_state(checkpoint_rng_state["numpy"])
    torch.random.set_rng_state(checkpoint_rng_state["cpu"])
    current_platform.random.set_rng_state_all(checkpoint_rng_state[current_platform.device_type])
File: roll/pipeline/base_pipeline.py (L32-32)
self.state = WorkerState()
File: roll/pipeline/base_pipeline.py (L43-48)
if self.pipeline_config.resume_from_checkpoint:
    self.resume_from_checkpoint = download_model(self.pipeline_config.resume_from_checkpoint)
    logger.info(f"resume_from_checkpoint: {self.resume_from_checkpoint}")
    load_dir = os.path.join(self.resume_from_checkpoint, "pipeline")
    self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")
File: roll/pipeline/base_pipeline.py (L50-54)
def resume_metrics():
    for metrics in self.state.log_history:
        self.tracker.log(values=metrics, step=metrics["system/step"])

self.resume_futures.append(self.executor.submit(resume_metrics))
File: roll/pipeline/base_pipeline.py (L88-92)
pipeline_save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id)
save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id, "pipeline")
self.state.save_to_json(save_dir=save_dir, tag="pipeline")
self.state.save_rng_state(save_dir=save_dir, tag="pipeline")
self.checkpoint_manager.upload(ckpt_id=ckpt_id, local_state_path=pipeline_save_dir)
File: mcore_adapter/src/mcore_adapter/trainer/trainer.py (L664-667)
if resume_from_checkpoint is not None and os.path.isfile(
    os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
    self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
Answer
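You asked why load_from_json is designed as a class method (@classmethod) rather than a regular method.
Why a class method is needed
load_from_json needs to be a class method because its job is to create and return a new instance of the class 1. This is a typical factory method.
Key differences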
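Regular instance method:
- It needs an existing instance before it can be called: instance.method()
- Its first parameter is self, which refers to that existing instance.
- It is used to operate on the state of an existing object.
Class method:
- It is called directly on the class: WorkerState.load_from_json(load_dir, tag)
- Its first parameter is cls, which refers to the class itself.
- It is used to create new instances.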
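The following example contrasts the two kinds of method using a hypothetical Counter class:

```python
class Counter:
    """Hypothetical class used only to contrast the two kinds of method."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def bump(self) -> None:
        # Instance method: `self` is an object that already exists
        self.value += 1

    @classmethod
    def starting_at(cls, value: int) -> "Counter":
        # Class method: `cls` is the class itself, so it can build a new object
        return cls(value)


c = Counter.starting_at(5)  # called on the class, no instance needed
c.bump()                    # called on an instance
print(c.value)  # 6
```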
Practical usage
In BasePipeline, resuming from a checkpoint requires building a WorkerState instance from scratch 2:
self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")
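What the pipeline needs here is a WorkerState built from the file's contents. The default instance created earlier (self.state = WorkerState()) is simply replaced. A class method can be called directly on the class name and returns a newly constructed instance, so no placeholder object is involved in loading.
Comparison with the constructor
You might ask: why not just call the __init__ constructor directly?
A class method offers a more flexible way to construct objects: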
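- Clear semantics: the name load_from_json states the intent ("load from a JSON file") explicitly, which reads better than calling __init__ directly.
- Multiple construction paths: a class can offer several class methods, each a different construction route (e.g. from_dict, from_json, from_yaml), whereas it has only one __init__.
- Preprocessing: a class method can read files, parse data, and so on before __init__ is called 3.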
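The following sketch shows a class with several alternative constructors, using a hypothetical RunConfig dataclass. from_json_file does the file I/O and parsing, then delegates to from_dict, which does its own preprocessing before __init__ runs. Dropping unknown keys is a choice made for this sketch only. The cited ROLL constructors pass everything straight to cls(**json.loads(text)), so an unexpected key would raise a TypeError there. A from_yaml constructor would have the same shape but needs a third-party YAML parser, so it is omitted.

```python
import json
import os
import tempfile
from dataclasses import dataclass, fields


@dataclass
class RunConfig:
    """Hypothetical config class used only to illustrate alternative constructors."""
    lr: float = 1e-5
    batch_size: int = 8

    @classmethod
    def from_dict(cls, data: dict) -> "RunConfig":
        # Preprocessing before __init__: keep only keys the dataclass declares
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

    @classmethod
    def from_json_file(cls, path: str) -> "RunConfig":
        # File I/O and parsing happen here, then construction is delegated to from_dict
        with open(path, "r", encoding="utf-8") as f:
            return cls.from_dict(json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "run.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"lr": 3e-4, "batch_size": 32, "comment": "ignored"}, f)
    from_file = RunConfig.from_json_file(path)

from_dict = RunConfig.from_dict({"batch_size": 16})
print(from_file)  # RunConfig(lr=0.0003, batch_size=32)
print(from_dict)  # RunConfig(lr=1e-05, batch_size=16)
```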
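The same pattern elsewhere
The pattern appears in several places in the ROLL codebase:
- MegatronArguments.from_json_file 4: creates a training-arguments instance from a JSON file
- PretrainedConfig.from_json_file 5: creates a model-config instance from a JSON file
- McaModelConfig.from_pretrained 6: creates a config instance from a pretrained path
All of these use the @classmethod decorator because they are alternative constructors.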
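Inheritance is one more reason these constructors call cls(...) instead of naming the class directly. BaseState and PipelineState below are hypothetical; the cited sources do not say whether WorkerState is ever subclassed.

```python
import json


class BaseState:
    def __init__(self, step: int = -1) -> None:
        self.step = step

    @classmethod
    def from_json_str(cls, text: str):
        # cls(...) builds whichever class the method was called on
        return cls(**json.loads(text))

    @staticmethod
    def from_json_str_hardcoded(text: str) -> "BaseState":
        # Hard-coding the class name always builds a BaseState, even via a subclass
        return BaseState(**json.loads(text))


class PipelineState(BaseState):
    """Hypothetical subclass, e.g. one adding pipeline-specific helpers."""


a = PipelineState.from_json_str('{"step": 3}')
b = PipelineState.from_json_str_hardcoded('{"step": 3}')
print(type(a).__name__, type(b).__name__)  # PipelineState BaseState
```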
What if it were a regular method?
Suppose it were rewritten as a regular method:
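def load_from_json(self, load_dir: str, tag):
    # ... read the JSON ...
    # self already exists here, so the natural outcome is to overwrite its
    # attributes rather than return a fresh instance
    ...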
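Creating a new instance from a file then no longer works cleanly, because:
- An instance has to exist before the call: state = WorkerState(); state.load_from_json(...)
- The method can only overwrite the existing object's attributes, or build a second object via type(self)(...) and leave the first one as an unused placeholder.
- Either way, the call no longer reads as "load an object from a file".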
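For contrast, here is a sketch of the instance-method alternative, using a hypothetical MutableState class. It works, but only by first creating a throwaway default object and then mutating it:

```python
import json


class MutableState:
    """Hypothetical class showing the instance-method alternative."""

    def __init__(self, step: int = -1) -> None:
        self.step = step

    def load_from_json_str(self, text: str) -> None:
        # Instance-method version: it can only overwrite attributes of `self`
        for key, value in json.loads(text).items():
            setattr(self, key, value)


state = MutableState()                   # a default object must exist first
state.load_from_json_str('{"step": 7}')  # then it is mutated in place
print(state.step)  # 7
```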
Notes
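Class methods are the standard Python way to implement factory methods and alternative constructors. The @classmethod decorator passes the class itself to the method as its first argument (cls), so the method can call cls(...) to create a new instance 1. This design is especially useful when objects need several creation paths, and it is a well-established idiom in Python object-oriented code.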
Citations
File: roll/utils/worker_state.py (L40-46)
@classmethod
def load_from_json(cls, load_dir: str, tag):
    """Create an instance from the content of `json_path`."""
    json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))
    with open(json_path, "r", encoding="utf-8") as f:
        text = f.read()
    return cls(**json.loads(text))
File: roll/pipeline/base_pipeline.py (L48-48)
self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")
File: mcore_adapter/src/mcore_adapter/training_args.py (L317-321)
@classmethod
def from_json_file(cls, json_file_path) -> "MegatronArguments":
    with open(json_file_path, "r", encoding="utf-8") as reader:
        text = reader.read()
    return cls(**json.loads(text))
File: mcore_adapter/src/mcore_adapter/models/model_config.py (L69-73)
@classmethod
def from_json_file(cls, json_file_path):
    with open(json_file_path, "r", encoding="utf-8") as reader:
        text = reader.read()
    return cls(**json.loads(text))
File: mcore_adapter/src/mcore_adapter/models/model_config.py (L124-125)
@classmethod
def from_pretrained(cls, model_name_or_path: str, args: Optional["TrainingArguments"] = None):
