TIME - MoE 模型代码 3.3——Time-MoE-main/time_moe/datasets/time_moe_window_dataset.py
源码:https://github.com/Time-MoE/Time-MoE
这段代码实现了两个用于时间序列数据处理的窗口化数据集类,主要用于将长序列切割成固定长度的子序列,为模型训练提供合适的输入格式。
1. 核心类:TimeMoEWindowDataset
1.1 功能概述
将长时间序列转换为固定长度的非重叠滑动窗口,每个窗口包含:
- 输入序列(
input_ids
):长度为context_length
- 标签序列(
labels
):长度为context_length + prediction_length
,与输入序列错位 1 个时间步 - 损失掩码(
loss_masks
):标记哪些位置需要计算损失
1.2 关键参数
context_length
:输入序列长度(历史信息)prediction_length
:预测序列长度(未来信息,默认为 0)stride
:窗口滑动步长(默认为窗口大小,即非重叠)
1.3 初始化逻辑
def __init__(self, dataset, context_length, prediction_length=0, stride=None):self.dataset = datasetself.context_length = context_lengthself.prediction_length = prediction_lengthself.window_size = context_length + prediction_lengthself.stride = stride or self.window_size # 默认非重叠# 构建子序列索引列表self.sub_seq_indexes = []for seq_idx in range(len(dataset)):n_points = dataset.get_sequence_length_by_idx(seq_idx)if n_points < 2:continue# 添加初始窗口self.sub_seq_indexes.append((seq_idx, 0))# 添加后续窗口(按stride滑动)for offset in range(self.stride, n_points - self.window_size - 1 + 1, self.stride):self.sub_seq_indexes.append((seq_idx, offset))
1.4数据获取逻辑
def __getitem__(self, seq_idx):seq_i, offset_i = self.sub_seq_indexes[seq_idx]# 提取窗口数据(包含额外1个点用于错位)seq = self.dataset[seq_i][offset_i: offset_i + self.window_size + 1]seq = np.array(seq, dtype=np.float32)# 创建损失掩码(标记有效位置)loss_mask = np.ones(len(seq) - 1, dtype=np.int32)# 处理序列长度不足的情况(填充0)n_pad = self.window_size + 1 - len(seq)if n_pad > 0:seq = np.pad(seq, (0, n_pad), 'constant', constant_values=0)loss_mask = np.pad(loss_mask, (0, n_pad), 'constant', constant_values=0)return {'input_ids': seq[:-1], # 输入序列'labels': seq[1:], # 标签序列(错位1步)'loss_masks': loss_mask # 损失掩码(忽略填充位置)}
2.增强类:UniversalTimeMoEWindowDataset
2.1 功能概述
实现了一种打包技术(pack technique),将多个短序列合并成一个固定长度的窗口,提高数据利用率和训练效率。
2.2 关键参数
shuffle
:是否随机打乱序列顺序(默认为 False)- 其他参数与
TimeMoEWindowDataset
类似
2.3 初始化逻辑
def __init__(self, dataset, context_length, prediction_length=0, shuffle=False):self.dataset = datasetself.window_size = context_length + prediction_lengthself.window_info_list = [] # 存储窗口信息(每个窗口包含多个子序列片段)cur_window_info = [] # 当前窗口的子序列片段num_cur_remaining_points = self.window_size # 当前窗口剩余可用长度# 遍历所有序列(可随机打乱)iterator = range(len(dataset))if shuffle:iterator = list(iterator)random.shuffle(iterator)for seq_idx in iterator:seq_len = dataset.get_sequence_length_by_idx(seq_idx)remaining_seq_len = seq_len# 将当前序列切割成多个片段,填充到窗口中while remaining_seq_len > 0:if remaining_seq_len < num_cur_remaining_points:# 当前序列剩余部分不足以填满窗口,全部加入cur_window_info.append((seq_idx, seq_len - remaining_seq_len, remaining_seq_len))num_cur_remaining_points -= remaining_seq_lenremaining_seq_len = 0else:# 当前序列剩余部分可以填满窗口,截取部分加入cur_window_info.append((seq_idx, seq_len - remaining_seq_len, num_cur_remaining_points))remaining_seq_len -= num_cur_remaining_points# 当前窗口已满,添加到结果列表并重置self.window_info_list.append(cur_window_info)num_cur_remaining_points = self.window_sizecur_window_info = []
2.4 数据获取逻辑
def __getitem__(self, window_idx):window_info = self.window_info_list[window_idx]seq = []# 从多个子序列片段构建完整窗口for seq_idx, start_idx, offset in window_info:part_seq = dataset[seq_idx][start_idx: start_idx + offset]seq.append(part_seq)# 合并所有片段if len(seq) == 1:seq = np.array(seq[0], dtype=np.float32)else:seq = np.concatenate(seq, axis=0, dtype=np.float32)return {'input_ids': seq[:-1], # 输入序列'labels': seq[1:], # 标签序列(错位1步)}
3.对比分析
-
TimeMoEWindowDataset
:- 适用于长序列数据
- 适合需要严格控制窗口独立性的场景
- 当
stride < window_size
时支持重叠窗口,用于增强数据多样性
-
UniversalTimeMoEWindowDataset
:- 适用于大量短序列数据
- 通过打包技术减少填充,提高训练效率
- 适合自回归模型(如 GPT 类模型),允许不同序列之间的信息流动
性能与内存权衡:
-
TimeMoEWindowDataset
:- 预计算所有窗口索引,内存开销较高(尤其是长序列)
- 数据访问速度快(直接索引)
-
UniversalTimeMoEWindowDataset
:- 动态构建窗口,内存开销低
- 数据访问时需要拼接多个片段,计算开销略高
4.总结
这两个类通过不同策略解决了时间序列数据的窗口化问题:
TimeMoEWindowDataset
:提供简单直观的滑动窗口实现,支持灵活的重叠策略UniversalTimeMoEWindowDataset
:通过序列打包技术优化短序列处理,提高训练效率
两者共同构成了一个完整的时间序列数据预处理工具链,为后续模型训练提供了标准化的输入格式。