Datawhale_PyPOTS_task6
基于自定义时间序列数据集的下游任务分析方法
方法一:基于插补数据的LSTM分类
-  
数据加载器:定义了
LoadImputedDataAndLabel类,用于将插补后的数据和标签转换为PyTorch的Dataset对象,便于后续使用DataLoader进行批量加载。 -  
模型定义:
ClassificationLSTM类定义了一个基于LSTM的分类模型,包含一个LSTM层和一个全连接层(fcn),用于将LSTM的输出映射到类别概率。 -  
训练过程:
-  
设置了训练轮数(
n_epochs=20)和早停机制(patience=5)——防止过拟合。 -  
使用Adam优化器,学习率为
1e-3。optimizer = torch.optim.Adam(model.parameters(), 1e-3) -  
在每个epoch中,模型在训练集上进行训练,并在验证集上评估损失。如果验证损失在连续
patience个epoch中没有减少,则提前停止训练。 -  
最后,模型加载验证损失最低时的权重,并在测试集上进行评估。
 
 -  
 
方法二:PyPOTS中的TimesNet模型进行端到端学习的分类
-  
模型初始化:
-  
创建了
TimesNet模型,参数包括时间步长(n_steps)、特征数量(n_features)、类别数量(n_classes)等。 -  
设置了模型的训练参数,如层数(
n_layers=2)、隐藏单元数量(d_model=64)、前馈网络维度(d_ffn=128)等。 -  
使用Adam优化器,学习率为
1e-3。 -  
设置了早停机制(
patience=5)和训练轮数(epochs=20)。 -  
模型保存,并设置只保存最佳模型。
 
 -  
 -  
训练过程:
-  
使用训练集和验证集进行训练。
 -  
每个epoch记录训练损失和验证损失,并根据验证损失进行早停。
 -  
最终保存最佳模型。
 
 -  
 
LSTM方法:
插补数据来补充数据集;但插补过程可能会引入误差
TimeNet方法:
直接带缺失值的原始数据来处理;
方法一类似于task2和task3的过程;
方法二则类似有task4的过程。
端到端学习:
BRITS模型
TimeNet模型
()
