Datawhale PyPOTS Time Series, May: Note 4

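End-to-end learning: a single model directly takes data containing missing values as input, with no separate imputation step beforehand.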
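To make "data containing missing values" concrete, here is a minimal sketch of what such an input can look like. It assumes, as PyPOTS does to my understanding, that missing entries are encoded as NaN in an array of shape (n_samples, n_steps, n_features). The toy sizes and the 80% missing rate are illustrative choices that loosely mirror the PhysioNet2012 run in this note, not real data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 4 samples, 48 time steps, 37 features
n_samples, n_steps, n_features = 4, 48, 37
X = rng.normal(size=(n_samples, n_steps, n_features))

# Randomly mark roughly 80% of the entries as missing by setting them to NaN
missing = rng.random(X.shape) < 0.8
X[missing] = np.nan

# The observation mask and the missing rate can be derived from the NaNs alone
observed_mask = ~np.isnan(X)
missing_rate = 1.0 - observed_mask.mean()
print(f"shape: {X.shape}, missing rate: {missing_rate:.2%}")
```

An end-to-end model consumes an array like `X` directly (wrapped as `{"X": X, "y": y}` in the script below), rather than a version that was filled in first.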

brits_classification.py

The complete code is as follows:

# brits_classification.py
from benchpots.datasets import preprocess_physionet2012
from pypots.classification import BRITS
from pypots.nn.functional.classification import calc_binary_classification_metrics


def main():
    print("📥 Loading the PhysioNet2012 dataset ...")
    # Load subset "set-a" with point-wise masking at rate 0.1
    physionet2012_dataset = preprocess_physionet2012(
        subset="set-a",
        pattern="point",
        rate=0.1,
    )

    # Each split is a dict holding the time series X and the labels y
    dataset_for_training = {
        "X": physionet2012_dataset["train_X"],
        "y": physionet2012_dataset["train_y"],
    }
    dataset_for_validating = {
        "X": physionet2012_dataset["val_X"],
        "y": physionet2012_dataset["val_y"],
    }
    dataset_for_testing = {
        "X": physionet2012_dataset["test_X"],
        "y": physionet2012_dataset["test_y"],
    }

    print("🧠 Initializing and training the BRITS classification model ...")
    brits = BRITS(
        n_steps=physionet2012_dataset["n_steps"],
        n_features=physionet2012_dataset["n_features"],
        n_classes=physionet2012_dataset["n_classes"],
        rnn_hidden_size=128,
        epochs=20,
        patience=5,
    )
    # The validation set is used to pick the best epoch
    brits.fit(dataset_for_training, dataset_for_validating)

    print("🔍 Predicting on the test set ...")
    brits_results = brits.predict(dataset_for_testing)
    brits_prediction = brits_results["classification"]

    print("📊 Computing binary classification metrics (ROC-AUC & PR-AUC) ...")
    classification_metrics = calc_binary_classification_metrics(
        brits_prediction,
        dataset_for_testing["y"],
    )
    print(f"\n✅ BRITS ROC-AUC on the test set: {classification_metrics['roc_auc']:.4f}")
    print(f"✅ BRITS PR-AUC on the test set: {classification_metrics['pr_auc']:.4f}\n")


if __name__ == "__main__":
    main()

Output:

(pypots-env) PS D:\Projects\pypots-experiments> python brits_classification.py
[time-series.ai ASCII-art banner omitted]
ai4ts v0.0.3 - building AI for unified time-series analysis, https://time-series.ai
📥 Loading the PhysioNet2012 dataset ...
2025-05-20 10:37:39 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link:
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-05-20 10:37:39 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-05-20 10:37:39 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-05-20 10:37:39 [INFO]: Loaded successfully!
2025-05-20 10:37:46 [WARNING]: Note that physionet_2012 has sparse observations in the time series, hence we don't add additional missing values to the training dataset.
2025-05-20 10:37:46 [INFO]: 23589 values masked out in the val set as ground truth, take 10.12% of the original observed values
2025-05-20 10:37:46 [INFO]: 29292 values masked out in the test set as ground truth, take 10.12% of the original observed values
2025-05-20 10:37:46 [INFO]: Total sample number: 3997
2025-05-20 10:37:46 [INFO]: Training set size: 2557 (63.97%)
2025-05-20 10:37:46 [INFO]: Validation set size: 640 (16.01%)
2025-05-20 10:37:46 [INFO]: Test set size: 800 (20.02%)
2025-05-20 10:37:46 [INFO]: Number of steps: 48
2025-05-20 10:37:46 [INFO]: Number of features: 37
2025-05-20 10:37:46 [INFO]: Train set missing rate: 79.81%
2025-05-20 10:37:46 [INFO]: Validating set missing rate: 81.57%
2025-05-20 10:37:46 [INFO]: Test set missing rate: 81.70%
🧠 Initializing and training the BRITS classification model ...
2025-05-20 10:37:46 [INFO]: No given device, using default device: cpu
2025-05-20 10:37:46 [WARNING]: ‼️ saving_path not given. Model files and tensorboard file will not be saved.
2025-05-20 10:37:46 [INFO]: Using customized CrossEntropy as the training loss function.
2025-05-20 10:37:46 [INFO]: Using customized CrossEntropy as the validation metric function.
2025-05-20 10:37:46 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,860
2025-05-20 10:38:20 [INFO]: Epoch 001 - training loss (CrossEntropy): 1.6426, validation CrossEntropy: 0.3454
2025-05-20 10:38:41 [INFO]: Epoch 002 - training loss (CrossEntropy): 1.3225, validation CrossEntropy: 0.3129
2025-05-20 10:39:01 [INFO]: Epoch 003 - training loss (CrossEntropy): 1.2331, validation CrossEntropy: 0.2936
2025-05-20 10:39:22 [INFO]: Epoch 004 - training loss (CrossEntropy): 1.1602, validation CrossEntropy: 0.2980
2025-05-20 10:39:42 [INFO]: Epoch 005 - training loss (CrossEntropy): 1.1155, validation CrossEntropy: 0.2747
2025-05-20 10:40:03 [INFO]: Epoch 006 - training loss (CrossEntropy): 1.0836, validation CrossEntropy: 0.2573
2025-05-20 10:40:23 [INFO]: Epoch 007 - training loss (CrossEntropy): 1.0511, validation CrossEntropy: 0.2454
2025-05-20 10:40:44 [INFO]: Epoch 008 - training loss (CrossEntropy): 1.0327, validation CrossEntropy: 0.2332
2025-05-20 10:41:05 [INFO]: Epoch 009 - training loss (CrossEntropy): 0.9995, validation CrossEntropy: 0.2043
2025-05-20 10:41:26 [INFO]: Epoch 010 - training loss (CrossEntropy): 0.9911, validation CrossEntropy: 0.2046
2025-05-20 10:41:46 [INFO]: Epoch 011 - training loss (CrossEntropy): 0.9635, validation CrossEntropy: 0.1807
2025-05-20 10:42:07 [INFO]: Epoch 012 - training loss (CrossEntropy): 0.9599, validation CrossEntropy: 0.1750
2025-05-20 10:42:27 [INFO]: Epoch 013 - training loss (CrossEntropy): 0.9384, validation CrossEntropy: 0.1563
2025-05-20 10:42:48 [INFO]: Epoch 014 - training loss (CrossEntropy): 0.9051, validation CrossEntropy: 0.1386
2025-05-20 10:43:08 [INFO]: Epoch 015 - training loss (CrossEntropy): 0.8942, validation CrossEntropy: 0.1226
2025-05-20 10:43:29 [INFO]: Epoch 016 - training loss (CrossEntropy): 0.8914, validation CrossEntropy: 0.1238
2025-05-20 10:43:49 [INFO]: Epoch 017 - training loss (CrossEntropy): 0.8703, validation CrossEntropy: 0.1163
2025-05-20 10:44:10 [INFO]: Epoch 018 - training loss (CrossEntropy): 0.8618, validation CrossEntropy: 0.1000
2025-05-20 10:44:30 [INFO]: Epoch 019 - training loss (CrossEntropy): 0.8520, validation CrossEntropy: 0.0887
2025-05-20 10:44:51 [INFO]: Epoch 020 - training loss (CrossEntropy): 0.8359, validation CrossEntropy: 0.0953
2025-05-20 10:44:51 [INFO]: Finished training. The best model is from epoch#19.
🔍 Predicting on the test set ...
📊 Computing binary classification metrics (ROC-AUC & PR-AUC) ...

✅ BRITS ROC-AUC on the test set: 0.5709
✅ BRITS PR-AUC on the test set: 0.4205
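The log reports that the best model comes from epoch 19 even though training ran for all 20 epochs with patience=5. The sketch below replays the validation CrossEntropy values from the log through a generic patience-based early-stopping rule to show why. The helper `track_early_stopping` is hypothetical and written for illustration; PyPOTS's internal bookkeeping may differ in its details.

```python
def track_early_stopping(val_losses, patience):
    """Return (best_epoch, best_loss, last_epoch_run) under a simple patience rule."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # Patience exhausted: stop before the epoch budget is used up
                return best_epoch, best_loss, epoch
    return best_epoch, best_loss, len(val_losses)


# Validation CrossEntropy per epoch, copied from the training log above
val_losses = [
    0.3454, 0.3129, 0.2936, 0.2980, 0.2747, 0.2573, 0.2454, 0.2332, 0.2043, 0.2046,
    0.1807, 0.1750, 0.1563, 0.1386, 0.1226, 0.1238, 0.1163, 0.1000, 0.0887, 0.0953,
]

best_epoch, best_loss, last_epoch = track_early_stopping(val_losses, patience=5)
print(best_epoch, best_loss, last_epoch)  # → 19 0.0887 20
```

The validation loss never fails to improve for more than one epoch in a row, so the patience of 5 is never exhausted and training uses the full 20-epoch budget. Epoch 20 (0.0953) is worse than epoch 19 (0.0887), which is why the epoch-19 weights are kept.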
