当前位置: 首页 > news >正文

TIME - MoE 模型代码 3.3——Time-MoE-main/time_moe/datasets/time_moe_window_dataset.py

源码:https://github.com/Time-MoE/Time-MoE

这段代码实现了两个用于时间序列数据处理的窗口化数据集类,主要用于将长序列切割成固定长度的子序列,为模型训练提供合适的输入格式。


1. 核心类:TimeMoEWindowDataset

1.1 功能概述

将长时间序列转换为固定长度的非重叠滑动窗口,每个窗口包含:

  • 输入序列input_ids):长度为context_length
  • 标签序列labels):长度为context_length + prediction_length,与输入序列错位 1 个时间步
  • 损失掩码loss_masks):标记哪些位置需要计算损失

1.2 关键参数

  • context_length:输入序列长度(历史信息)
  • prediction_length:预测序列长度(未来信息,默认为 0)
  • stride:窗口滑动步长(默认为窗口大小,即非重叠)

1.3 初始化逻辑

def __init__(self, dataset, context_length, prediction_length=0, stride=None):self.dataset = datasetself.context_length = context_lengthself.prediction_length = prediction_lengthself.window_size = context_length + prediction_lengthself.stride = stride or self.window_size  # 默认非重叠# 构建子序列索引列表self.sub_seq_indexes = []for seq_idx in range(len(dataset)):n_points = dataset.get_sequence_length_by_idx(seq_idx)if n_points < 2:continue# 添加初始窗口self.sub_seq_indexes.append((seq_idx, 0))# 添加后续窗口(按stride滑动)for offset in range(self.stride, n_points - self.window_size - 1 + 1, self.stride):self.sub_seq_indexes.append((seq_idx, offset))

1.4数据获取逻辑

def __getitem__(self, seq_idx):seq_i, offset_i = self.sub_seq_indexes[seq_idx]# 提取窗口数据(包含额外1个点用于错位)seq = self.dataset[seq_i][offset_i: offset_i + self.window_size + 1]seq = np.array(seq, dtype=np.float32)# 创建损失掩码(标记有效位置)loss_mask = np.ones(len(seq) - 1, dtype=np.int32)# 处理序列长度不足的情况(填充0)n_pad = self.window_size + 1 - len(seq)if n_pad > 0:seq = np.pad(seq, (0, n_pad), 'constant', constant_values=0)loss_mask = np.pad(loss_mask, (0, n_pad), 'constant', constant_values=0)return {'input_ids': seq[:-1],      # 输入序列'labels': seq[1:],         # 标签序列(错位1步)'loss_masks': loss_mask    # 损失掩码(忽略填充位置)}

2.增强类:UniversalTimeMoEWindowDataset

2.1 功能概述

实现了一种打包技术(pack technique),将多个短序列合并成一个固定长度的窗口,提高数据利用率和训练效率。

2.2 关键参数

  • shuffle:是否随机打乱序列顺序(默认为 False)
  • 其他参数与TimeMoEWindowDataset类似

2.3 初始化逻辑

def __init__(self, dataset, context_length, prediction_length=0, shuffle=False):self.dataset = datasetself.window_size = context_length + prediction_lengthself.window_info_list = []  # 存储窗口信息(每个窗口包含多个子序列片段)cur_window_info = []        # 当前窗口的子序列片段num_cur_remaining_points = self.window_size  # 当前窗口剩余可用长度# 遍历所有序列(可随机打乱)iterator = range(len(dataset))if shuffle:iterator = list(iterator)random.shuffle(iterator)for seq_idx in iterator:seq_len = dataset.get_sequence_length_by_idx(seq_idx)remaining_seq_len = seq_len# 将当前序列切割成多个片段,填充到窗口中while remaining_seq_len > 0:if remaining_seq_len < num_cur_remaining_points:# 当前序列剩余部分不足以填满窗口,全部加入cur_window_info.append((seq_idx, seq_len - remaining_seq_len, remaining_seq_len))num_cur_remaining_points -= remaining_seq_lenremaining_seq_len = 0else:# 当前序列剩余部分可以填满窗口,截取部分加入cur_window_info.append((seq_idx, seq_len - remaining_seq_len, num_cur_remaining_points))remaining_seq_len -= num_cur_remaining_points# 当前窗口已满,添加到结果列表并重置self.window_info_list.append(cur_window_info)num_cur_remaining_points = self.window_sizecur_window_info = []

2.4 数据获取逻辑

def __getitem__(self, window_idx):window_info = self.window_info_list[window_idx]seq = []# 从多个子序列片段构建完整窗口for seq_idx, start_idx, offset in window_info:part_seq = dataset[seq_idx][start_idx: start_idx + offset]seq.append(part_seq)# 合并所有片段if len(seq) == 1:seq = np.array(seq[0], dtype=np.float32)else:seq = np.concatenate(seq, axis=0, dtype=np.float32)return {'input_ids': seq[:-1],  # 输入序列'labels': seq[1:],     # 标签序列(错位1步)}

3.对比分析

  1. TimeMoEWindowDataset

    • 适用于长序列数据
    • 适合需要严格控制窗口独立性的场景
    • stride < window_size时支持重叠窗口,用于增强数据多样性
  2. UniversalTimeMoEWindowDataset

    • 适用于大量短序列数据
    • 通过打包技术减少填充,提高训练效率
    • 适合自回归模型(如 GPT 类模型),允许不同序列之间的信息流动

性能与内存权衡:

  • TimeMoEWindowDataset

    • 预计算所有窗口索引,内存开销较高(尤其是长序列)
    • 数据访问速度快(直接索引)
  • UniversalTimeMoEWindowDataset

    • 动态构建窗口,内存开销低
    • 数据访问时需要拼接多个片段,计算开销略高

4.总结

这两个类通过不同策略解决了时间序列数据的窗口化问题:

  • TimeMoEWindowDataset:提供简单直观的滑动窗口实现,支持灵活的重叠策略
  • UniversalTimeMoEWindowDataset:通过序列打包技术优化短序列处理,提高训练效率

两者共同构成了一个完整的时间序列数据预处理工具链,为后续模型训练提供了标准化的输入格式。

相关文章:

  • 【排错】dify1.3.1插件市场安装报错问题
  • 协议路由更改路径配置
  • 计算机设计大赛山东省赛区软件开发赛道线上答辩复盘
  • 记录一次window2012r2安装配置oracle11g的过程-出现的错误以及解决方法
  • GPT-4o, GPT 4.5, GPT 4.1, O3, O4-mini等模型的区别与联系
  • 嵌入式学习笔记 - 运算放大器的共模抑制比
  • Java 原生实现代码沙箱之Java 程序安全控制(OJ判题系统第2期)——设计思路、实现步骤、代码实现
  • Java基础:代理
  • JavaScript篇:async/await 错误处理指南:优雅捕获异常,告别失控的 Promise!
  • Linux系统下安装mongodb
  • ensp的华为小实验
  • JavaSE核心知识点02面向对象编程02-06(泛型)
  • Metasploit 4.22.7:企业级渗透测试新突破
  • Open CASCADE学习|管道壳体生成
  • AI Coding的发展之路:从概念到改变世界的旅程
  • 学习黑客5 分钟深入浅出理解Linux Packages Software Repos
  • GMS 与非 GMS:有何区别?
  • 【工具记录分享】提取bilibili视频字幕
  • Lingma:云效 MCP 使用
  • 容器填充函数fill和memset对比总结
  • 一生要出片的年轻人,买爆相机
  • 巴基斯坦空袭印度多地空军基地,巴战机进入印领空
  • 白宫启动“返乡计划” ,鼓励非法移民自愿离开美国
  • 复旦大学文科杰出教授裘锡圭逝世,享年90岁
  • 专访|高圆圆:像鸟儿一样,柔弱也自由
  • 4月外汇储备增加410亿美元,黄金储备连续6个月增加