Datawhale PyPOTS time-series study notes: May, note 4
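End-to-end learning: a single model takes data that still contains missing values directly as input, instead of first imputing them in a separate step.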
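"Directly" here means there is no separate imputation step. The model still needs some numeric signal at the missing positions. In the BRITS paper (Cao et al., 2018), the recurrent cells receive the observed values together with a missing mask m_t and a time gap δ_t, which is the time since each feature was last observed, and they learn to impute along the way. The minimal NumPy sketch below builds those two auxiliary inputs for one series, following the paper's definitions. The function name `mask_and_deltas` is hypothetical, and this is not PyPOTS's internal code.

```python
import numpy as np


def mask_and_deltas(x, timestamps):
    """Missing mask m and time gaps delta for one series (BRITS paper definitions).

    x: array [n_steps, n_features], NaN marks a missing value.
    timestamps: array [n_steps] with the time s_t of each step.
    """
    m = (~np.isnan(x)).astype(float)  # 1 = observed, 0 = missing
    delta = np.zeros_like(m)          # delta_1 = 0
    for t in range(1, x.shape[0]):
        gap = timestamps[t] - timestamps[t - 1]
        # If the feature was missing at t-1, keep accumulating the gap since its last observation.
        delta[t] = gap + (1.0 - m[t - 1]) * delta[t - 1]
    return m, delta


x = np.array([[1.0, np.nan], [np.nan, np.nan], [3.0, 4.0]])
m, delta = mask_and_deltas(x, np.arange(3.0))
print(m.tolist())      # [[1.0, 0.0], [0.0, 0.0], [1.0, 1.0]]
print(delta.tolist())  # [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
```

For evenly spaced steps, such as the 48 steps used below, `timestamps` can simply be `np.arange(n_steps)`.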
brits_classification.py
The complete code:
```python
# brits_classification.py
from benchpots.datasets import preprocess_physionet2012
from pypots.classification import BRITS
from pypots.nn.functional.classification import calc_binary_classification_metrics


def main():
    print("📥 Loading the PhysioNet2012 dataset...")
    # set-a of PhysioNet2012; additionally mask 10% of observed points as ground truth
    physionet2012_dataset = preprocess_physionet2012(
        subset="set-a",
        pattern="point",
        rate=0.1,
    )
    # PyPOTS takes dicts: "X" = [n_samples, n_steps, n_features] with NaN for missing, "y" = labels
    dataset_for_training = {
        "X": physionet2012_dataset["train_X"],
        "y": physionet2012_dataset["train_y"],
    }
    dataset_for_validating = {
        "X": physionet2012_dataset["val_X"],
        "y": physionet2012_dataset["val_y"],
    }
    dataset_for_testing = {
        "X": physionet2012_dataset["test_X"],
        "y": physionet2012_dataset["test_y"],
    }

    print("🧠 Initializing and training the BRITS classification model ...")
    brits = BRITS(
        n_steps=physionet2012_dataset["n_steps"],
        n_features=physionet2012_dataset["n_features"],
        n_classes=physionet2012_dataset["n_classes"],
        rnn_hidden_size=128,
        epochs=20,   # maximum number of training epochs
        patience=5,  # stop early after 5 epochs without validation improvement
    )
    brits.fit(dataset_for_training, dataset_for_validating)

    print("🔍 Predicting on the test set ...")
    brits_results = brits.predict(dataset_for_testing)
    brits_prediction = brits_results["classification"]

    print("📊 Computing binary classification metrics (ROC-AUC & PR-AUC) ...")
    classification_metrics = calc_binary_classification_metrics(
        brits_prediction,
        dataset_for_testing["y"],
    )
    print(f"\n✅ BRITS ROC-AUC on the test set: {classification_metrics['roc_auc']:.4f}")
    print(f"✅ BRITS PR-AUC on the test set: {classification_metrics['pr_auc']:.4f}\n")


if __name__ == "__main__":
    main()
```
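Before starting a full run, it can help to smoke-test the pipeline on random data that has the same layout the script passes to `fit` and `predict`. The dict keys "X" and "y" and the NaN convention come from the script above. The helper `make_toy_dataset`, its default missing rate, and the random values are hypothetical stand-ins, not PhysioNet2012 data.

```python
import numpy as np


def make_toy_dataset(n_samples, n_steps=48, n_features=37, missing_rate=0.8, seed=0):
    """Hypothetical helper: random data in the {"X", "y"} layout used by the script."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, n_steps, n_features))
    # Drop roughly `missing_rate` of the entries; missing values are marked with NaN.
    X[rng.random(X.shape) < missing_rate] = np.nan
    y = rng.integers(0, 2, size=n_samples)  # binary labels (n_classes = 2)
    return {"X": X, "y": y}


toy_train = make_toy_dataset(64)
print(toy_train["X"].shape, toy_train["y"].shape)  # (64, 48, 37) (64,)
```

A call such as `brits.fit(make_toy_dataset(64), make_toy_dataset(16, seed=1))` should exercise the same fit/predict path within seconds. The real 20-epoch run below took about 7 minutes on CPU.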
Output:
(pypots-env) PS D:\Projects\pypots-experiments> python brits_classification.py
ai4ts v0.0.3 - building AI for unified time-series analysis, https://time-series.ai
📥 Loading the PhysioNet2012 dataset...
2025-05-20 10:37:39 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link:
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-05-20 10:37:39 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-05-20 10:37:39 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-05-20 10:37:39 [INFO]: Loaded successfully!
2025-05-20 10:37:46 [WARNING]: Note that physionet_2012 has sparse observations in the time series, hence we don't add additional missing values to the training dataset.
2025-05-20 10:37:46 [INFO]: 23589 values masked out in the val set as ground truth, take 10.12% of the original observed values
2025-05-20 10:37:46 [INFO]: 29292 values masked out in the test set as ground truth, take 10.12% of the original observed values
2025-05-20 10:37:46 [INFO]: Total sample number: 3997
2025-05-20 10:37:46 [INFO]: Training set size: 2557 (63.97%)
2025-05-20 10:37:46 [INFO]: Validation set size: 640 (16.01%)
2025-05-20 10:37:46 [INFO]: Test set size: 800 (20.02%)
2025-05-20 10:37:46 [INFO]: Number of steps: 48
2025-05-20 10:37:46 [INFO]: Number of features: 37
2025-05-20 10:37:46 [INFO]: Train set missing rate: 79.81%
2025-05-20 10:37:46 [INFO]: Validating set missing rate: 81.57%
2025-05-20 10:37:46 [INFO]: Test set missing rate: 81.70%
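The statistics in this part of the log can be reproduced with simple arithmetic. This is an inference from the numbers, not a reading of benchpots' source:

- **Split sizes.** 2557/640/800 out of 3997 matches two successive 80/20 holdout splits: the test set is held out first, then the validation set is taken from the remainder. Each held-out count is rounded up, which is what scikit-learn's `train_test_split` does.
- **Masking.** Only the val/test sets get 10% of their observed values hidden as ground truth. The log shows 10.12% rather than exactly 10%, which suggests the library samples points randomly instead of hiding an exact count. For simplicity, the hypothetical helper `mask_observed_points` below hides an exact count.
- **Missing rate.** This is presumably the share of NaN entries. The extra masked points are one plausible reason the val/test rates are about 2 points above the training rate.

```python
import math

import numpy as np


def split_sizes(n_total, test_frac=0.2, val_frac=0.2):
    """Test split first, then validation from the remainder; held-out counts rounded up."""
    n_test = math.ceil(n_total * test_frac)
    n_rest = n_total - n_test
    n_val = math.ceil(n_rest * val_frac)
    return n_rest - n_val, n_val, n_test


def mask_observed_points(X, rate=0.1, seed=0):
    """Hypothetical helper: hide `rate` of the observed entries; X keeps them as ground truth."""
    rng = np.random.default_rng(seed)
    observed = np.flatnonzero(~np.isnan(X))  # flat (C-order) indices of observed values
    hidden = rng.choice(observed, size=int(round(len(observed) * rate)), replace=False)
    X_masked = X.copy()
    X_masked.flat[hidden] = np.nan
    indicating_mask = np.isnan(X_masked) & ~np.isnan(X)  # True where a value was hidden
    return X_masked, indicating_mask


def missing_rate(X):
    """Share of NaN entries over the whole array."""
    return float(np.isnan(X).mean())


n_total = 3997
for name, n in zip(("train", "val", "test"), split_sizes(n_total)):
    print(f"{name}: {n} ({n / n_total:.2%})")
# train: 2557 (63.97%)
# val: 640 (16.01%)
# test: 800 (20.02%)

# Toy array: 8 of 10 entries observed, i.e. 20% missing before masking.
X = np.array([[1.0, np.nan, 3.0, 4.0, 5.0], [6.0, 7.0, np.nan, 9.0, 10.0]])
X_masked, indicating_mask = mask_observed_points(X, rate=0.25)
print(missing_rate(X), missing_rate(X_masked), int(indicating_mask.sum()))  # 0.2 0.4 2
```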
🧠 Initializing and training the BRITS classification model ...
2025-05-20 10:37:46 [INFO]: No given device, using default device: cpu
2025-05-20 10:37:46 [WARNING]: ‼️ saving_path not given. Model files and tensorboard file will not be saved.
2025-05-20 10:37:46 [INFO]: Using customized CrossEntropy as the training loss function.
2025-05-20 10:37:46 [INFO]: Using customized CrossEntropy as the validation metric function.
2025-05-20 10:37:46 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,860
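The 239,860 trainable parameters can be accounted for layer by layer. The breakdown below assumes the BRITS layout from the paper, as implemented in PyPOTS. Under that layout there are two RITS branches (forward and backward). Each branch has:

- a full-matrix temporal decay for the hidden state;
- a diagonal-masked decay for the inputs;
- history and feature regressions;
- a combining layer;
- a PyTorch `LSTMCell` over values plus mask.

On top of the two branches sits one linear classification head per direction, with n_classes = 2. The layer names are descriptive labels, not PyPOTS attribute names. With n_features = 37 and rnn_hidden_size = 128, the breakdown reproduces the logged count exactly.

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias


def lstm_cell(n_in, n_hidden):
    # PyTorch LSTMCell: W_ih, W_hh and two bias vectors, each covering the 4 gates
    return 4 * n_hidden * (n_in + n_hidden) + 2 * 4 * n_hidden


def brits_param_count(n_features=37, hidden=128, n_classes=2):
    rits = (
        linear(n_features, hidden)            # temporal decay for the hidden state
        + linear(n_features, n_features)      # temporal decay for inputs (diag mask is a buffer)
        + linear(hidden, n_features)          # history regression
        + linear(n_features, n_features)      # feature regression (diag mask is a buffer)
        + linear(2 * n_features, n_features)  # combining weight
        + lstm_cell(2 * n_features, hidden)   # RNN input = values concatenated with mask
    )
    classifiers = 2 * linear(hidden, n_classes)  # one head per direction
    return 2 * rits + classifiers                # forward + backward RITS


print(f"{brits_param_count():,}")  # 239,860
```

Most of the budget (208,896 of 239,860) sits in the two LSTM cells, so `rnn_hidden_size` is the main knob for model size.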
2025-05-20 10:38:20 [INFO]: Epoch 001 - training loss (CrossEntropy): 1.6426, validation CrossEntropy: 0.3454
2025-05-20 10:38:41 [INFO]: Epoch 002 - training loss (CrossEntropy): 1.3225, validation CrossEntropy: 0.3129
2025-05-20 10:39:01 [INFO]: Epoch 003 - training loss (CrossEntropy): 1.2331, validation CrossEntropy: 0.2936
2025-05-20 10:39:22 [INFO]: Epoch 004 - training loss (CrossEntropy): 1.1602, validation CrossEntropy: 0.2980
2025-05-20 10:39:42 [INFO]: Epoch 005 - training loss (CrossEntropy): 1.1155, validation CrossEntropy: 0.2747
2025-05-20 10:40:03 [INFO]: Epoch 006 - training loss (CrossEntropy): 1.0836, validation CrossEntropy: 0.2573
2025-05-20 10:40:23 [INFO]: Epoch 007 - training loss (CrossEntropy): 1.0511, validation CrossEntropy: 0.2454
2025-05-20 10:40:44 [INFO]: Epoch 008 - training loss (CrossEntropy): 1.0327, validation CrossEntropy: 0.2332
2025-05-20 10:41:05 [INFO]: Epoch 009 - training loss (CrossEntropy): 0.9995, validation CrossEntropy: 0.2043
2025-05-20 10:41:26 [INFO]: Epoch 010 - training loss (CrossEntropy): 0.9911, validation CrossEntropy: 0.2046
2025-05-20 10:41:46 [INFO]: Epoch 011 - training loss (CrossEntropy): 0.9635, validation CrossEntropy: 0.1807
2025-05-20 10:42:07 [INFO]: Epoch 012 - training loss (CrossEntropy): 0.9599, validation CrossEntropy: 0.1750
2025-05-20 10:42:27 [INFO]: Epoch 013 - training loss (CrossEntropy): 0.9384, validation CrossEntropy: 0.1563
2025-05-20 10:42:48 [INFO]: Epoch 014 - training loss (CrossEntropy): 0.9051, validation CrossEntropy: 0.1386
2025-05-20 10:43:08 [INFO]: Epoch 015 - training loss (CrossEntropy): 0.8942, validation CrossEntropy: 0.1226
2025-05-20 10:43:29 [INFO]: Epoch 016 - training loss (CrossEntropy): 0.8914, validation CrossEntropy: 0.1238
2025-05-20 10:43:49 [INFO]: Epoch 017 - training loss (CrossEntropy): 0.8703, validation CrossEntropy: 0.1163
2025-05-20 10:44:10 [INFO]: Epoch 018 - training loss (CrossEntropy): 0.8618, validation CrossEntropy: 0.1000
2025-05-20 10:44:30 [INFO]: Epoch 019 - training loss (CrossEntropy): 0.8520, validation CrossEntropy: 0.0887
2025-05-20 10:44:51 [INFO]: Epoch 020 - training loss (CrossEntropy): 0.8359, validation CrossEntropy: 0.0953
2025-05-20 10:44:51 [INFO]: Finished training. The best model is from epoch#19.
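With `patience=5`, training stops early only after 5 consecutive epochs without a new best validation score. In this run the validation CrossEntropy never failed to improve two epochs in a row, so all 20 epochs ran. The weights from epoch 19, which had the lowest value (0.0887), were kept. The minimal sketch below implements that rule on the logged values. It is a generic patience-based early stopper, and PyPOTS's exact bookkeeping may differ.

```python
def early_stopping(val_losses, patience):
    """Patience-based early stopping (lower loss is better).

    Returns (best_epoch, last_epoch_run), both 1-based.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop early
    return best_epoch, len(val_losses)


# Validation CrossEntropy for epochs 1-20, copied from the log above.
val_ce = [0.3454, 0.3129, 0.2936, 0.2980, 0.2747, 0.2573, 0.2454, 0.2332, 0.2043, 0.2046,
          0.1807, 0.1750, 0.1563, 0.1386, 0.1226, 0.1238, 0.1163, 0.1000, 0.0887, 0.0953]
print(early_stopping(val_ce, patience=5))  # (19, 20)
```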
🔍 Predicting on the test set ...
📊 Computing binary classification metrics (ROC-AUC & PR-AUC) ...

✅ BRITS ROC-AUC on the test set: 0.5709
✅ BRITS PR-AUC on the test set: 0.4205
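Two reference points help read these numbers. An ROC-AUC of 0.5 corresponds to a random ranking of samples. The PR-AUC of a random ranking is roughly the fraction of positive samples. The pure-Python sketch below shows what the two metrics measure:

- **ROC-AUC** is the probability that a randomly chosen positive gets a higher score than a randomly chosen negative.
- **PR-AUC** is estimated here as step-wise average precision.

`calc_binary_classification_metrics` presumably wraps scikit-learn. Its PR-AUC may use trapezoidal interpolation rather than the step-wise sum below, so the two values can differ slightly. Labels are assumed to be 0/1 integers.

```python
def roc_auc(scores, labels):
    """P(random positive scores above random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(scores, labels):
    """Step-wise area under the precision-recall curve over distinct thresholds."""
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("need at least one positive sample")
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    tp = fp = 0
    ap = prev_recall = 0.0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Consume every sample tied at this threshold before measuring precision/recall.
        while i < len(pairs) and pairs[i][0] == threshold:
            tp += pairs[i][1]
            fp += 1 - pairs[i][1]
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


scores = [0.9, 0.8, 0.3, 0.2]  # predicted probability of the positive class
labels = [1, 0, 1, 0]
print(roc_auc(scores, labels))                      # 0.75
print(round(average_precision(scores, labels), 4))  # 0.8333
```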