当前位置: 首页 > news >正文

TIME - MoE 模型代码 4——Time-MoE-main/run_eval.py

源码:https://github.com/Time-MoE/Time-MoE

这段代码是一个用于评估 Time-MoE 模型性能的脚本,它支持分布式环境下的模型评估,通过计算 MSE 和 MAE 等指标来衡量模型在时间序列预测任务上的表现。代码的核心功能包括:模型加载、数据处理、预测生成以及多节点分布式评估。


关键模块与组件

1. 环境初始化与分布式设置

def setup_nccl(rank, world_size, master_addr='127.0.0.1', master_port=9899):dist.init_process_group("nccl", init_method='tcp://{}:{}'.format(master_addr, master_port), rank=rank,world_size=world_size)
  • 该函数使用 NCCL 后端初始化 PyTorch 分布式训练环境
  • 通过 TCP 协议连接主节点,实现多 GPU 或多节点通信
  • rank 表示当前进程 ID,world_size 表示总进程数

2. 评估指标体系

class SumEvalMetric:def __init__(self, name, init_val: float = 0.0):self.name = nameself.value = init_valdef push(self, preds, labels, **kwargs):self.value += self._calculate(preds, labels, **kwargs)class MSEMetric(SumEvalMetric):def _calculate(self, preds, labels, **kwargs):return torch.sum((preds - labels) ** 2)class MAEMetric(SumEvalMetric):def _calculate(self, preds, labels, **kwargs):return torch.sum(torch.abs(preds - labels))
  • 采用面向对象设计,基类 SumEvalMetric 定义了评估指标的基本结构
  • MSEMetric 和 MAEMetric 继承自基类,分别实现均方误差和平均绝对误差计算
  • push 方法用于累积每个批次的评估结果

3. 模型加载与预测模块

class TimeMoE:def __init__(self, model_path, device, context_length, prediction_length, **kwargs):try:from time_moe.models.modeling_time_moe import TimeMoeForPredictionmodel = TimeMoeForPrediction.from_pretrained(model_path,device_map=device,torch_dtype='auto',)except:model = AutoModelForCausalLM.from_pretrained(model_path,device_map=device,torch_dtype='auto',trust_remote_code=True,)def predict(self, batch):outputs = model.generate(inputs=batch['inputs'].to(device).to(model.dtype),max_new_tokens=prediction_length,)preds = outputs[:, -prediction_length:]labels = batch['labels'].to(device)return preds, labels
  • 支持两种模型加载方式:原生 Time-MoE 模型或通过 transformers 库加载的通用模型
  • 使用from_pretrained方法加载预训练权重,并自动处理设备映射和数据类型转换
  • predict 方法通过 generate 接口生成预测结果,提取最后 prediction_length 个时间步作为预测值

4. 数据处理流程

if args.data.endswith('.csv'):dataset = BenchmarkEvalDataset(args.data,context_length=context_length,prediction_length=prediction_length,)
else:dataset = GeneralEvalDataset(args.data,context_length=context_length,prediction_length=prediction_length,)if torch.cuda.is_available() and dist.is_initialized():sampler = DistributedSampler(dataset=dataset, shuffle=False)
else:sampler = Nonetest_dl = DataLoader(dataset=dataset,batch_size=batch_size,sampler=sampler,shuffle=False,num_workers=2,prefetch_factor=2,
)
  • 根据数据文件格式选择不同的数据集类
  • 支持分布式环境下的数据采样,确保各进程处理不同的数据分片
  • 数据加载器配置了多线程数据读取和预取,优化数据处理性能

5. 评估主流程

acc_count = 0
with torch.no_grad():for idx, batch in enumerate(tqdm(test_dl)):preds, labels = model.predict(batch)for metric in metric_list:metric.push(preds, labels)acc_count += count_num_tensor_elements(preds)# 分布式环境下的结果聚合
if is_dist:stat_tensor = torch.tensor(metric_tensors).to(model.device)gathered_results = [torch.zeros_like(stat_tensor) for _ in range(world_size)]dist.all_gather(gathered_results, stat_tensor)all_stat = torch.stack(gathered_results, dim=0).sum(dim=0)
else:all_stat = metric_tensors# 计算最终评估结果
count = all_stat[-1]
for i, metric in enumerate(metric_list):val = all_stat[i] / countitem[metric.name] = float(val.cpu().numpy())
  • 使用 torch.no_grad () 上下文管理器关闭梯度计算,提高推理速度
  • 遍历数据集,累积每个批次的预测结果和评估指标
  • 在分布式环境下,使用 all_gather 操作收集所有进程的统计数据
  • 最终在主进程上计算并打印全局评估结果

高级特性解析

1. 自适应上下文长度设置

if args.context_length is None:if args.prediction_length == 96:args.context_length = 512elif args.prediction_length == 192:args.context_length = 1024elif args.prediction_length == 336:args.context_length = 2048elif args.prediction_length == 720:args.context_length = 3072else:args.context_length = args.prediction_length * 4
  • 根据预测长度自动设置合适的上下文长度
  • 预测长度越长,所需的历史上下文信息也越多
  • 默认使用预测长度的 4 倍作为上下文长度

2. 分布式结果聚合

stat_tensor = torch.tensor(metric_tensors).to(model.device)
gathered_results = [torch.zeros_like(stat_tensor) for _ in range(world_size)]
dist.all_gather(gathered_results, stat_tensor)
all_stat = torch.stack(gathered_results, dim=0).sum(dim=0)
  • 使用 all_gather 操作将所有进程的统计数据收集到每个进程中
  • 对收集到的结果进行求和,得到全局统计数据
  • 确保最终评估结果基于所有数据分片

3. 动态设备映射与数据类型处理

model = TimeMoeForPrediction.from_pretrained(model_path,device_map=device,torch_dtype='auto',
)
  • device_map 参数自动处理模型在多 GPU 间的分布
  • torch_dtype='auto' 根据硬件自动选择最优数据类型
  • 支持混合精度推理,提高计算效率

使用方法与参数说明

parser = argparse.ArgumentParser('TimeMoE Evaluate')
parser.add_argument('--model', '-m', type=str, default='Maple728/TimeMoE-50M', help='Model path')
parser.add_argument('--data', '-d', type=str, help='Benchmark data path')
parser.add_argument('--batch_size', '-b', type=int, default=32, help='Batch size of evaluation')
parser.add_argument('--context_length', '-c', type=int, help='Context length')
parser.add_argument('--prediction_length', '-p', type=int, default=96, help='Prediction length')
  • --model:指定要评估的模型路径
  • --data:指定评估数据集路径
  • --batch_size:评估时的批次大小
  • --context_length:输入的历史上下文长度
  • --prediction_length:要预测的未来时间步长度

总结

这段代码实现了一个完整的 Time-MoE 模型评估系统,具有以下特点:

  1. 支持分布式环境下的高效评估
  2. 提供了 MSE 和 MAE 等常用评估指标
  3. 能够处理不同格式的时间序列数据
  4. 自动适应不同的预测长度和上下文长度
  5. 优化了模型加载和推理过程,支持混合精度计算

这个评估脚本可以帮助研究人员和工程师准确衡量 Time-MoE 模型在各种时间序列预测任务上的性能表现。

我们的实验

print问题

1.输出内容的来源与原因

(1)模型初始化信息
logging.info(f'>>> Model dtype: {model.dtype}; Attention:{model.config._attn_implementation}')
  • 位置TimeMoE类的__init__方法。
  • 原因:记录模型的数据类型(如float32)和注意力机制实现方式(如eagerflash_attention_2)。
(2)进度条
for idx, batch in enumerate(tqdm(test_dl)):...
  • 位置evaluate函数的主循环。
  • 原因:使用tqdm库显示评估进度,便于用户了解当前完成情况。
(3)各进程的局部评估结果
print(f'{rank} - {ret_metric}')
  • 位置evaluate函数的结果聚合前。
  • 原因:打印每个进程计算的局部 MSE 和 MAE 指标(分布式环境下每个 GPU 计算一部分数据)。
(4)汇总后的全局评估结果
logging.info(item)
  • 位置evaluate函数中rank == 0的条件分支。
  • 原因:在主进程中汇总所有进程的结果,输出最终的全局评估指标。

3. 输出功能的实现代码

(1)logging 模块的配置与使用
# 隐式配置(未在代码中显示,但transformers库默认配置了logging)
import logging# 使用示例
logging.info(...)  # 输出INFO级别的日志
  • 特点:日志格式通常包含时间戳、日志级别和消息内容。
(2)print 语句的使用
for idx, batch in enumerate(tqdm(test_dl)):...
  • 位置evaluate函数中,在分布式结果聚合前。
  • 原因:记录模型的数据类型(如float32)和注意力机制实现方式(如eagerflash_attention_2)。
(3)tqdm 进度条
from tqdm import tqdmfor batch in tqdm(test_dl):  # 包装数据加载器,显示进度...
  • 功能:动态显示评估进度(如100%|██████████| 100/100 [00:30<00:00])。

4. 分布式环境下的输出规则

  • 局部结果:每个进程(GPU)都会打印自己计算的指标(通过print)。
  • 全局结果:仅主进程(rank == 0)汇总并输出最终指标(通过logging)。
  • 示例
    # 分布式环境下(如4卡)可能的输出:
    0 - {'mse': tensor(0.0123, device='cuda:0'), 'mae': tensor(0.0987, device='cuda:0')}
    1 - {'mse': tensor(0.0119, device='cuda:1'), 'mae': tensor(0.0976, device='cuda:1')}
    2 - {'mse': tensor(0.0121, device='cuda:2'), 'mae': tensor(0.0981, device='cuda:2')}
    3 - {'mse': tensor(0.0125, device='cuda:3'), 'mae': tensor(0.0993, device='cuda:3')}# 主进程汇总后的结果:
    INFO: {'model': ..., 'mse': 0.0122, 'mae': 0.0984}
    

总结

  • 输出内容:模型信息、评估进度、各进程局部指标、全局汇总指标。
  • 输出原因:监控评估过程、验证模型性能、支持分布式环境调试。
  • 实现代码
    • logging模块(记录模型配置和最终结果)。
    • print语句(打印各进程局部结果)。
    • tqdm库(显示进度条)。

相关文章:

  • 图形化编程革命:iVX携手AI 原生开发范式
  • MNIST 数据并行 Data Parallel - DP
  • 【目标检测系列】YOLOV1解读
  • Go语言实现豆瓣电影Top250爬虫
  • 掌握 void 类型在函数返回值中的应用
  • MIT 6.S081 2020 Lab3 page tables 个人全流程
  • 添加文字标签
  • Docker使用ClickHouse | ClickHouse 配置用户名密码 | ClickHouse 可视化 | windows系统 | 镜像
  • 类型别名与接口的对比与选择
  • Javascript:数组和函数
  • 【心海资源】【最新话费盗u】【未测】提币对方官方波场+没有任何加密+无后门+前端VUE
  • 专业课复习笔记 5
  • Three.js + React 实战系列 - 职业经历区实现解析 Experience 组件✨(互动动作 + 3D 角色 + 点击切换动画)
  • 【星海随笔】信息安全法律法规概述
  • 单片机调用printf概率性跑飞解决方法
  • 大疆卓驭嵌入式面经及参考答案
  • 论文阅读与写作:《从探索到突破:解密科研和论文写作的思维密码》
  • 《从零构建一个简易的IOC容器,理解Spring的核心思想》
  • GitHub打开缓慢甚至失败的解决办法
  • 【QT】UDP通讯本地调试
  • 内塔尼亚胡:以军将在未来几天“全力进入”加沙
  • 香港根据《维护国家安全条例》订立附属法例
  • 男子退机票被收票价90%的手续费,律师:虽然合规,但显失公平
  • 上海现有超12.3万名注册护士,本科及以上学历占一半
  • 河北邯郸一酒店婚宴发生火灾:众人惊险逃生,酒店未买保险
  • “拼好假”的年轻人,今年有哪些旅游新玩法?