Datawhale_PyPOTS_task6
基于自定义时间序列数据集的下游任务分析方法
方法一:基于插补数据的LSTM分类
-
数据加载器:定义了
LoadImputedDataAndLabel
类,用于将插补后的数据和标签转换为PyTorch的Dataset
对象,便于后续使用DataLoader
进行批量加载。 -
模型定义:
ClassificationLSTM
类定义了一个基于LSTM的分类模型,包含一个LSTM层和一个全连接层(fcn
),用于将LSTM的输出映射到类别概率。 -
训练过程:
-
设置了训练轮数(
n_epochs=20
)和早停机制(patience=5
)——防止过拟合。 -
使用Adam优化器,学习率为
1e-3
。optimizer = torch.optim.Adam(model.parameters(), 1e-3)
-
在每个epoch中,模型在训练集上进行训练,并在验证集上评估损失。如果验证损失在连续
patience
个epoch中没有减少,则提前停止训练。 -
最后,模型加载验证损失最低时的权重,并在测试集上进行评估。
-
方法二:PyPOTS中的TimesNet模型进行端到端学习的分类
-
模型初始化:
-
创建了
TimesNet
模型,参数包括时间步长(n_steps
)、特征数量(n_features
)、类别数量(n_classes
)等。 -
设置了模型的训练参数,如层数(
n_layers=2
)、隐藏单元数量(d_model=64
)、前馈网络维度(d_ffn=128
)等。 -
使用Adam优化器,学习率为
1e-3
。 -
设置了早停机制(
patience=5
)和训练轮数(epochs=20
)。 -
模型保存,并设置只保存最佳模型。
-
-
训练过程:
-
使用训练集和验证集进行训练。
-
每个epoch记录训练损失和验证损失,并根据验证损失进行早停。
-
最终保存最佳模型。
-
LSTM方法:
插补数据来补充数据集;但插补过程可能会引入误差
TimeNet方法:
直接带缺失值的原始数据来处理;
方法一类似于task2和task3的过程;
方法二则类似有task4的过程。
端到端学习:
BRITS模型
TimeNet模型
()