TIME - MoE 模型代码 4——Time-MoE-main/run_eval.py
源码:https://github.com/Time-MoE/Time-MoE
这段代码是一个用于评估 Time-MoE 模型性能的脚本,它支持分布式环境下的模型评估,通过计算 MSE 和 MAE 等指标来衡量模型在时间序列预测任务上的表现。代码的核心功能包括:模型加载、数据处理、预测生成以及多节点分布式评估。
关键模块与组件
1. 环境初始化与分布式设置
def setup_nccl(rank, world_size, master_addr='127.0.0.1', master_port=9899):dist.init_process_group("nccl", init_method='tcp://{}:{}'.format(master_addr, master_port), rank=rank,world_size=world_size)
- 该函数使用 NCCL 后端初始化 PyTorch 分布式训练环境
- 通过 TCP 协议连接主节点,实现多 GPU 或多节点通信
- rank 表示当前进程 ID,world_size 表示总进程数
2. 评估指标体系
class SumEvalMetric:def __init__(self, name, init_val: float = 0.0):self.name = nameself.value = init_valdef push(self, preds, labels, **kwargs):self.value += self._calculate(preds, labels, **kwargs)class MSEMetric(SumEvalMetric):def _calculate(self, preds, labels, **kwargs):return torch.sum((preds - labels) ** 2)class MAEMetric(SumEvalMetric):def _calculate(self, preds, labels, **kwargs):return torch.sum(torch.abs(preds - labels))
- 采用面向对象设计,基类 SumEvalMetric 定义了评估指标的基本结构
- MSEMetric 和 MAEMetric 继承自基类,分别实现均方误差和平均绝对误差计算
- push 方法用于累积每个批次的评估结果
3. 模型加载与预测模块
class TimeMoE:def __init__(self, model_path, device, context_length, prediction_length, **kwargs):try:from time_moe.models.modeling_time_moe import TimeMoeForPredictionmodel = TimeMoeForPrediction.from_pretrained(model_path,device_map=device,torch_dtype='auto',)except:model = AutoModelForCausalLM.from_pretrained(model_path,device_map=device,torch_dtype='auto',trust_remote_code=True,)def predict(self, batch):outputs = model.generate(inputs=batch['inputs'].to(device).to(model.dtype),max_new_tokens=prediction_length,)preds = outputs[:, -prediction_length:]labels = batch['labels'].to(device)return preds, labels
- 支持两种模型加载方式:原生 Time-MoE 模型或通过 transformers 库加载的通用模型
- 使用
from_pretrained
方法加载预训练权重,并自动处理设备映射和数据类型转换 - predict 方法通过 generate 接口生成预测结果,提取最后 prediction_length 个时间步作为预测值
4. 数据处理流程
if args.data.endswith('.csv'):dataset = BenchmarkEvalDataset(args.data,context_length=context_length,prediction_length=prediction_length,)
else:dataset = GeneralEvalDataset(args.data,context_length=context_length,prediction_length=prediction_length,)if torch.cuda.is_available() and dist.is_initialized():sampler = DistributedSampler(dataset=dataset, shuffle=False)
else:sampler = Nonetest_dl = DataLoader(dataset=dataset,batch_size=batch_size,sampler=sampler,shuffle=False,num_workers=2,prefetch_factor=2,
)
- 根据数据文件格式选择不同的数据集类
- 支持分布式环境下的数据采样,确保各进程处理不同的数据分片
- 数据加载器配置了多线程数据读取和预取,优化数据处理性能
5. 评估主流程
acc_count = 0
with torch.no_grad():for idx, batch in enumerate(tqdm(test_dl)):preds, labels = model.predict(batch)for metric in metric_list:metric.push(preds, labels)acc_count += count_num_tensor_elements(preds)# 分布式环境下的结果聚合
if is_dist:stat_tensor = torch.tensor(metric_tensors).to(model.device)gathered_results = [torch.zeros_like(stat_tensor) for _ in range(world_size)]dist.all_gather(gathered_results, stat_tensor)all_stat = torch.stack(gathered_results, dim=0).sum(dim=0)
else:all_stat = metric_tensors# 计算最终评估结果
count = all_stat[-1]
for i, metric in enumerate(metric_list):val = all_stat[i] / countitem[metric.name] = float(val.cpu().numpy())
- 使用 torch.no_grad () 上下文管理器关闭梯度计算,提高推理速度
- 遍历数据集,累积每个批次的预测结果和评估指标
- 在分布式环境下,使用 all_gather 操作收集所有进程的统计数据
- 最终在主进程上计算并打印全局评估结果
高级特性解析
1. 自适应上下文长度设置
if args.context_length is None:if args.prediction_length == 96:args.context_length = 512elif args.prediction_length == 192:args.context_length = 1024elif args.prediction_length == 336:args.context_length = 2048elif args.prediction_length == 720:args.context_length = 3072else:args.context_length = args.prediction_length * 4
- 根据预测长度自动设置合适的上下文长度
- 预测长度越长,所需的历史上下文信息也越多
- 默认使用预测长度的 4 倍作为上下文长度
2. 分布式结果聚合
stat_tensor = torch.tensor(metric_tensors).to(model.device)
gathered_results = [torch.zeros_like(stat_tensor) for _ in range(world_size)]
dist.all_gather(gathered_results, stat_tensor)
all_stat = torch.stack(gathered_results, dim=0).sum(dim=0)
- 使用 all_gather 操作将所有进程的统计数据收集到每个进程中
- 对收集到的结果进行求和,得到全局统计数据
- 确保最终评估结果基于所有数据分片
3. 动态设备映射与数据类型处理
model = TimeMoeForPrediction.from_pretrained(model_path,device_map=device,torch_dtype='auto',
)
- device_map 参数自动处理模型在多 GPU 间的分布
- torch_dtype='auto' 根据硬件自动选择最优数据类型
- 支持混合精度推理,提高计算效率
使用方法与参数说明
parser = argparse.ArgumentParser('TimeMoE Evaluate')
parser.add_argument('--model', '-m', type=str, default='Maple728/TimeMoE-50M', help='Model path')
parser.add_argument('--data', '-d', type=str, help='Benchmark data path')
parser.add_argument('--batch_size', '-b', type=int, default=32, help='Batch size of evaluation')
parser.add_argument('--context_length', '-c', type=int, help='Context length')
parser.add_argument('--prediction_length', '-p', type=int, default=96, help='Prediction length')
--model
:指定要评估的模型路径--data
:指定评估数据集路径--batch_size
:评估时的批次大小--context_length
:输入的历史上下文长度--prediction_length
:要预测的未来时间步长度
总结
这段代码实现了一个完整的 Time-MoE 模型评估系统,具有以下特点:
- 支持分布式环境下的高效评估
- 提供了 MSE 和 MAE 等常用评估指标
- 能够处理不同格式的时间序列数据
- 自动适应不同的预测长度和上下文长度
- 优化了模型加载和推理过程,支持混合精度计算
这个评估脚本可以帮助研究人员和工程师准确衡量 Time-MoE 模型在各种时间序列预测任务上的性能表现。
我们的实验
print问题
1.输出内容的来源与原因
(1)模型初始化信息
logging.info(f'>>> Model dtype: {model.dtype}; Attention:{model.config._attn_implementation}')
- 位置:
TimeMoE
类的__init__
方法。 - 原因:记录模型的数据类型(如
float32
)和注意力机制实现方式(如eager
或flash_attention_2
)。
(2)进度条
for idx, batch in enumerate(tqdm(test_dl)):...
- 位置:
evaluate
函数的主循环。 - 原因:使用
tqdm
库显示评估进度,便于用户了解当前完成情况。
(3)各进程的局部评估结果
print(f'{rank} - {ret_metric}')
- 位置:
evaluate
函数的结果聚合前。 - 原因:打印每个进程计算的局部 MSE 和 MAE 指标(分布式环境下每个 GPU 计算一部分数据)。
(4)汇总后的全局评估结果
logging.info(item)
- 位置:
evaluate
函数中rank == 0
的条件分支。 - 原因:在主进程中汇总所有进程的结果,输出最终的全局评估指标。
3. 输出功能的实现代码
(1)logging 模块的配置与使用
# 隐式配置(未在代码中显示,但transformers库默认配置了logging)
import logging# 使用示例
logging.info(...) # 输出INFO级别的日志
- 特点:日志格式通常包含时间戳、日志级别和消息内容。
(2)print 语句的使用
for idx, batch in enumerate(tqdm(test_dl)):...
- 位置:
evaluate
函数中,在分布式结果聚合前。 - 原因:记录模型的数据类型(如
float32
)和注意力机制实现方式(如eager
或flash_attention_2
)。
(3)tqdm 进度条
from tqdm import tqdmfor batch in tqdm(test_dl): # 包装数据加载器,显示进度...
- 功能:动态显示评估进度(如
100%|██████████| 100/100 [00:30<00:00]
)。
4. 分布式环境下的输出规则
- 局部结果:每个进程(GPU)都会打印自己计算的指标(通过
print
)。 - 全局结果:仅主进程(
rank == 0
)汇总并输出最终指标(通过logging
)。 - 示例:
# 分布式环境下(如4卡)可能的输出: 0 - {'mse': tensor(0.0123, device='cuda:0'), 'mae': tensor(0.0987, device='cuda:0')} 1 - {'mse': tensor(0.0119, device='cuda:1'), 'mae': tensor(0.0976, device='cuda:1')} 2 - {'mse': tensor(0.0121, device='cuda:2'), 'mae': tensor(0.0981, device='cuda:2')} 3 - {'mse': tensor(0.0125, device='cuda:3'), 'mae': tensor(0.0993, device='cuda:3')}# 主进程汇总后的结果: INFO: {'model': ..., 'mse': 0.0122, 'mae': 0.0984}
总结
- 输出内容:模型信息、评估进度、各进程局部指标、全局汇总指标。
- 输出原因:监控评估过程、验证模型性能、支持分布式环境调试。
- 实现代码:
logging
模块(记录模型配置和最终结果)。print
语句(打印各进程局部结果)。tqdm
库(显示进度条)。