当前位置: 首页 > news >正文

【大模型训练】加载load_state 中的一些技巧 工厂设计模式

Answer

您询问的是 WorkerState 类的作用,特别是 load_from_json 方法的使用方式 1 。

WorkerState 的作用

WorkerState 是一个用于保存和恢复训练状态的数据类,主要用于checkpoint和训练恢复 2 。它包含三个核心字段:

  • step: 当前训练步数
  • log_history: 训练过程中的指标历史记录列表
  • kv: 键值对字典,用于存储额外的状态信息

load_from_json 方法的实现

load_from_json 是一个类方法,用于从JSON文件中恢复 WorkerState 实例 3 :

@classmethod
def load_from_json(cls, load_dir: str, tag):"""Create an instance from the content of `json_path`."""json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))

该方法:

  1. 根据 load_dirtag 构造JSON文件路径(格式为 worker_state_{tag}.json) 4
  2. 读取JSON文件内容
  3. 将JSON反序列化为字典,并用 ** 解包传递给类构造函数

在 Pipeline 中的使用

BasePipeline 中,WorkerState 用于保存和恢复整个pipeline的训练状态 5 :

初始化时恢复状态

当从checkpoint恢复训练时,pipeline会加载之前保存的状态 6 :

if self.pipeline_config.resume_from_checkpoint:self.resume_from_checkpoint = download_model(self.pipeline_config.resume_from_checkpoint)logger.info(f"resume_from_checkpoint: {self.resume_from_checkpoint}")load_dir = os.path.join(self.resume_from_checkpoint, "pipeline")self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

这里使用 tag="pipeline" 来标识这是pipeline级别的状态文件 7 。

恢复训练指标

加载状态后,pipeline会将历史指标重新记录到tracker中 8 :

def resume_metrics():for metrics in self.state.log_history:self.tracker.log(values=metrics, step=metrics["system/step"])self.resume_futures.append(self.executor.submit(resume_metrics))

保存状态

在checkpoint时,pipeline会保存当前状态 9 :

save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id, "pipeline")
self.state.save_to_json(save_dir=save_dir, tag="pipeline")
self.state.save_rng_state(save_dir=save_dir, tag="pipeline")

配合 RNG 状态保存

WorkerState 还提供了静态方法来保存和加载随机数生成器(RNG)状态,确保训练的可重现性 10 :

  • save_rng_state: 保存Python、NumPy、PyTorch的随机状态
  • load_rng_state: 恢复这些随机状态

这些方法在checkpoint时与 save_to_json/load_from_json 配合使用 11 。

在 Trainer 中的类似用法

McaTrainer 中也有类似的状态恢复逻辑 12 :

if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))

这里使用的是 TrainerState 而非 WorkerState,但模式相同。

Notes

WorkerState 的设计遵循了简单的序列化/反序列化模式,使用JSON格式存储训练状态 13 。tag 参数允许在同一目录下保存多个不同角色的状态文件(如 “pipeline”, “actor_train” 等)。这种设计使得ROLL能够在训练中断后精确恢复到之前的状态,包括训练步数、指标历史和随机数状态 14 。

Wiki pages you might want to explore:

  • API Reference (alibaba/ROLL)
  • Strategy Pattern and Backend Abstraction (alibaba/ROLL)

Citations

File: roll/utils/worker_state.py (L17-17)

WORKER_STATE_NAME = "worker_state_{tag}.json"

File: roll/utils/worker_state.py (L20-30)

@dataclass
class WorkerState:step: int = -1log_history: List[Dict[str, float]] = Nonekv: Dict[str, Union[float, Dict]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []if self.kv is None:self.kv = {}

File: roll/utils/worker_state.py (L32-38)

    def save_to_json(self, save_dir: str, tag):"""Save the content of this instance in JSON format inside `json_path`."""json_path = os.path.join(save_dir, WORKER_STATE_NAME.format(tag=tag))json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"os.makedirs(os.path.dirname(json_path), exist_ok=True)with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)

File: roll/utils/worker_state.py (L40-46)

    @classmethoddef load_from_json(cls, load_dir: str, tag):"""Create an instance from the content of `json_path`."""json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))

File: roll/utils/worker_state.py (L48-77)

    @staticmethoddef save_rng_state(save_dir, tag):# Save RNG state in non-distributed trainingrng_states = {"python": random.getstate(),"numpy": np.random.get_state(),"cpu": torch.random.get_rng_state(),current_platform.device_type: current_platform.random.get_rng_state_all(),}os.makedirs(save_dir, exist_ok=True)torch.save(rng_states, os.path.join(save_dir, f"rng_state_{tag}.pth"))@staticmethoddef load_rng_state(load_dir, tag):# Load RNG states from `checkpoint`if load_dir is None:returnrng_file = os.path.join(load_dir, f"rng_state_{tag}.pth")if not os.path.isfile(rng_file):logger.info(f"Didn't find an RNG file for process {tag}, if you are resuming a training that ""wasn't launched in a distributed fashion, reproducibility is not guaranteed.")returncheckpoint_rng_state = torch.load(rng_file, weights_only=False)random.setstate(checkpoint_rng_state["python"])np.random.set_state(checkpoint_rng_state["numpy"])torch.random.set_rng_state(checkpoint_rng_state["cpu"])current_platform.random.set_rng_state_all(checkpoint_rng_state[current_platform.device_type])

File: roll/pipeline/base_pipeline.py (L32-32)

        self.state = WorkerState()

File: roll/pipeline/base_pipeline.py (L43-48)

        if self.pipeline_config.resume_from_checkpoint:self.resume_from_checkpoint = download_model(self.pipeline_config.resume_from_checkpoint)logger.info(f"resume_from_checkpoint: {self.resume_from_checkpoint}")load_dir = os.path.join(self.resume_from_checkpoint, "pipeline")self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

File: roll/pipeline/base_pipeline.py (L50-54)

            def resume_metrics():for metrics in self.state.log_history:self.tracker.log(values=metrics, step=metrics["system/step"])self.resume_futures.append(self.executor.submit(resume_metrics))

File: roll/pipeline/base_pipeline.py (L88-92)

            pipeline_save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id)save_dir = os.path.join(self.pipeline_config.output_dir, "pipeline", ckpt_id, "pipeline")self.state.save_to_json(save_dir=save_dir, tag="pipeline")self.state.save_rng_state(save_dir=save_dir, tag="pipeline")self.checkpoint_manager.upload(ckpt_id=ckpt_id, local_state_path=pipeline_save_dir)

File: mcore_adapter/src/mcore_adapter/trainer/trainer.py (L664-667)

        if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))

Answer

您询问的是为什么 load_from_json 要设计成类方法(@classmethod)而不是普通方法

类方法的必要性

load_from_json 必须是类方法,因为它的作用是创建并返回一个新的类实例 1 。这是一个典型的工厂方法模式

关键区别

普通实例方法:

  • 需要先有一个实例才能调用: instance.method()
  • 第一个参数是 self,指向已存在的实例
  • 用于操作已有对象的状态

类方法:

  • 直接通过类调用: WorkerState.load_from_json(load_dir, tag)
  • 第一个参数是 cls,指向类本身
  • 用于创建新实例

实际使用场景

BasePipeline 中,从checkpoint恢复时需要从零创建一个 WorkerState 实例 2 :

self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

此时还没有现成的 WorkerState 实例,所以不可能用普通方法调用。类方法允许直接通过类名调用,并返回一个新构造的实例

与构造函数的对比

您可能会问:为什么不直接用 __init__ 构造函数?

类方法提供了更灵活的构造方式:

  1. 语义清晰: load_from_json 明确表达"从JSON文件加载"的意图,比直接调用 __init__ 更易读
  2. 多种构造方式: 可以有多个类方法提供不同的构造路径(如 from_dict, from_json, from_yaml),而 __init__ 只能有一个
  3. 预处理逻辑: 类方法可以在调用 __init__ 前进行文件读取、数据解析等操作 3

相同模式在其他地方的应用

这个模式在ROLL代码库中被广泛使用:

  • MegatronArguments.from_json_file 4 : 从JSON文件创建训练参数实例
  • PretrainedConfig.from_json_file 5 : 从JSON文件创建模型配置实例
  • McaModelConfig.from_pretrained 6 : 从预训练路径创建配置实例

所有这些都使用 @classmethod 装饰器,因为它们都是替代构造函数(alternative constructors)

如果改成普通方法会怎样

假设改成普通方法:

def load_from_json(self, load_dir: str, tag):# ... 读取JSON ...# 但这里无法返回新实例,只能修改self

这样就无法实现从文件创建新实例的功能,因为:

  1. 调用前必须先有实例: state = WorkerState(); state.load_from_json(...)
  2. 无法返回新对象,只能修改现有对象的属性
  3. 违反了"从文件加载"的语义

Notes

类方法是Python中实现工厂模式和替代构造函数的标准做法@classmethod 装饰器让方法接收类本身作为第一个参数(cls),从而可以调用 cls(...) 来创建新实例 1 。这种设计在需要多种对象创建方式时非常有用,是面向对象设计的最佳实践

Wiki pages you might want to explore:

  • Core Architecture (alibaba/ROLL)

Citations

File: roll/utils/worker_state.py (L40-46)

    @classmethoddef load_from_json(cls, load_dir: str, tag):"""Create an instance from the content of `json_path`."""json_path = os.path.join(load_dir, WORKER_STATE_NAME.format(tag=tag))with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))

File: roll/pipeline/base_pipeline.py (L48-48)

            self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline")

File: mcore_adapter/src/mcore_adapter/training_args.py (L317-321)

    @classmethoddef from_json_file(cls, json_file_path) -> "MegatronArguments":with open(json_file_path, "r", encoding="utf-8") as reader:text = reader.read()return cls(**json.loads(text))

File: mcore_adapter/src/mcore_adapter/models/model_config.py (L69-73)

    @classmethoddef from_json_file(cls, json_file_path):with open(json_file_path, "r", encoding="utf-8") as reader:text = reader.read()return cls(**json.loads(text))

File: mcore_adapter/src/mcore_adapter/models/model_config.py (L124-125)

    @classmethoddef from_pretrained(cls, model_name_or_path: str, args: Optional["TrainingArguments"] = None):
http://www.dtcms.com/a/586722.html

相关文章:

  • 地产公司做网站维护写代码么电子商务网站建设规划报告
  • 网站建设中图片尺寸深圳做网站
  • 蒙自网站建设wordpress通知搜索引擎收录
  • 玩转Rust高级应用 如何理解 Rust 实现免疫数据竞争的关键是Send 和 Sync 这两个 trait
  • 强人工智能软件、人的奖惩系统(强化学习系统)
  • 青县网站制作局域网建设网站
  • 优化网站排名公司邢台中高风险地区查询
  • php 网站响应时间课程网站设计建设
  • 小喇叭开始广播——以太网与WiFi协议
  • 环境感知模型
  • 网站设计制作教程天眼查河南建设网站公司
  • 怎么制作网站详细教程视频什么什么网站
  • 东莞网站开发营销哈尔滨建站的系统
  • html5 wap 网站模板西安网站建设制作公司
  • 第四十四篇|语言教育的结构可计算性:大阪观光商务日本语学院的语义建模实践
  • 自动驾驶-判断前后左右
  • 网站开发亿码酷流量网站推广页面 英语
  • vps网站空间沧州兼职网站建设
  • 网站权重如何速度增加福州小程序开发平台
  • FAML 完全入门指南:新一代动态配置语言
  • srcType instanceof Class 及泛型 vs 普通类
  • 上海网站制作公司有哪些网站建设服务包含内容
  • 章丘做网站优化网站优化无限关键词设置
  • Java线程通信:多线程程序中的高效协作!
  • 一个彩票网站建设徐州seo公司
  • 自己动手建立网站3个人网站 不用备案
  • 建设网站的建设费用包括星星wordpress模板
  • 湖北做网站的网站建设分金手指专业二
  • 飞牛NAS中安装Navidrome音乐文件中文标签乱码问题解决、安装FntermX终端
  • 合肥公司网站建设wordpress 下一页