当前位置: 首页 > news >正文

用 Python 轻松实现时间序列预测:Darts 时间序列混合器(TSMixer)Time Series Mixer

文中内容仅限技术学习与代码实践参考,市场存在不确定性,技术分析需谨慎验证,不构成任何投资建议。

Darts

Darts 是一个 Python 库,用于对时间序列进行用户友好型预测和异常检测。它包含多种模型,从 ARIMA 等经典模型到深度神经网络。所有预测模型都能以类似 scikit-learn 的方式使用 fit()predict() 函数。该库还可以轻松地对模型进行回溯测试,将多个模型的预测结果结合起来,并将外部数据考虑在内。Darts 支持单变量和多变量时间序列和模型。基于 ML 的模型可以在包含多个时间序列的潜在大型数据集上进行训练,其中一些模型还为概率预测提供了丰富的支持。

时间序列混合器(TSMixer)

Time Series Mixer (TSMixer)

这个 notebook 逐步展示了如何使用 Darts 的 TSMixerModel 并将其与 TiDEModel 进行基准测试。

TSMixer(Time-series Mixer)是一种用于时间序列预测的全 MLP 架构。

它通过整合历史时间序列数据、未来已知输入和静态上下文信息来实现这一目标。该架构使用条件特征混合和混合层的组合来处理和组合这些不同类型的数据,以实现有效的预测。

在 Darts 中,该模型支持所有类型的协变量(过去的、未来的和/或静态的)。

在此处查看原始论文和模型描述。

据作者称,该模型在多变量预测任务上优于几种最先进的模型。

让我们看看它在 ETTh1 和 ETTh2 数据集上与 TideModel 的表现如何。

# 如果在本地工作,修复 python 路径
from utils import fix_pythonpath_if_working_locallyfix_pythonpath_if_working_locally()
%matplotlib inline
%load_ext autoreload
%autoreload 2
%matplotlib inline
import warningswarnings.filterwarnings("ignore")
import logginglogging.disable(logging.CRITICAL)import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.callbacks.early_stopping import EarlyStoppingfrom darts import concatenate
from darts.dataprocessing.transformers.scaler import Scaler
from darts.datasets import ETTh1Dataset, ETTh2Dataset
from darts.metrics import mql
from darts.models import TiDEModel, TSMixerModel
from darts.utils.callbacks import TFMProgressBar
from darts.utils.likelihood_models.torch import QuantileRegression

数据加载和准备

我们考虑 ETTh1 和 ETTh2 数据集,它们包含电力变压器的每小时多变量数据(负载、油温……)。您可以在此处找到更多信息。

我们将为每个变压器时间序列添加静态信息,以标识它是 ETTh1 还是 ETTh2 变压器。TSMixer 和 TiDE 都可以利用这些信息。

series = []
for idx, ds in enumerate([ETTh1Dataset, ETTh2Dataset]):trafo = ds().load().astype(np.float32)trafo = trafo.with_static_covariates(pd.DataFrame({"transformer_id": [idx]}))series.append(trafo)
series[0].to_dataframe()
componentHUFLHULLMUFLMULLLUFLLULLOT
date
2016-07-01 00:00:005.8272.0091.5990.4624.2031.34030.531000
2016-07-01 01:00:005.6932.0761.4920.4264.1421.37127.787001
2016-07-01 02:00:005.1571.7411.2790.3553.7771.21827.787001
2016-07-01 03:00:005.0901.9421.2790.3913.8071.27925.044001
2016-07-01 04:00:005.3581.9421.4920.4623.8681.27921.948000
2018-06-26 15:00:00-1.6743.550-5.6152.1323.4721.52310.904000
2018-06-26 16:00:00-5.4924.287-9.1322.2743.5331.67511.044000
2018-06-26 17:00:002.8133.818-0.8172.0973.7161.52310.271000
2018-06-26 18:00:009.2433.8185.4722.0973.6551.4329.778000
2018-06-26 19:00:0010.1143.5506.1831.5643.7161.4629.567000

17420 行 × 7 列

在训练之前,我们将数据分为训练集、验证集和测试集。模型将从训练集中学习,使用验证集来确定何时停止训练,最后在测试集上进行评估。

train, val, test = [], [], []
for trafo in series:train_, temp = trafo.split_after(0.6)val_, test_ = temp.split_after(0.5)train.append(train_)val.append(val_)test.append(test_)

让我们看看每个变压器的第一个列 “HUFL” 的分割情况

show_col = "HUFL"
for idx, (train_, val_, test_) in enumerate(zip(train, val, test)):train_[show_col].plot(label=f"train_trafo_{idx}")val_[show_col].plot(label=f"val_trafo_{idx}")test_[show_col].plot(label=f"test_trafo_{idx}")

现在让我们对数据进行缩放。为了避免从验证集和测试集中泄露信息,我们根据训练集的特性对数据进行缩放。

scaler = Scaler()  # 默认使用 sklearn 的 MinMaxScaler
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

模型参数设置

样板代码并不有趣,尤其是在训练多个模型以比较性能的情况下。为了避免这种情况,我们使用一个通用配置,该配置可与任何 Darts TorchForecastingModel 一起使用。

关于这些参数的一些有趣之处:

  • 梯度裁剪: 通过在批次梯度上设置上限来减轻反向传播期间梯度爆炸的问题。

  • 学习率: 模型的大部分学习是在早期阶段完成的。随着训练的进行,降低学习率通常有助于微调模型。也就是说,它也可能导致严重的过拟合。

  • 早停: 为了避免过拟合,我们可以使用早停。它监控验证集上的指标,并在指标不再根据自定义条件改善时停止训练。

  • 似然和损失函数: 您可以使用 likelihood 使模型具有概率性,或使用 loss_fn 使其具有确定性。在此 notebook 中,我们使用分位数回归训练概率模型。

  • 可逆实例归一化: 使用可逆实例归一化,在大多数情况下可以提高模型性能。

  • 编码器: 我们可以对时间轴/日历信息进行编码,并使用 add_encoders 将它们作为过去或未来的协变量。在这里,我们将小时、星期几和月份的循环编码作为未来的协变量

def create_params(input_chunk_length: int,output_chunk_length: int,full_training=True,
):# 早停:此设置在验证损失在 10 个 epoch 内未减少超过 1e-5 时停止训练early_stopper = EarlyStopping(monitor="val_loss",patience=10,min_delta=1e-5,mode="min",)# PyTorch Lightning 训练器参数(您可以添加任何自定义回调)if full_training:limit_train_batches = Nonelimit_val_batches = Nonemax_epochs = 200batch_size = 256else:limit_train_batches = 20limit_val_batches = 10max_epochs = 40batch_size = 64# 仅显示训练和预测进度条progress_bar = TFMProgressBar(enable_sanity_check_bar=False, enable_validation_bar=False)pl_trainer_kwargs = {"gradient_clip_val": 1,"max_epochs": max_epochs,"limit_train_batches": limit_train_batches,"limit_val_batches": limit_val_batches,"accelerator": "auto","callbacks": [early_stopper, progress_bar],}# 优化器设置,默认使用 Adam# optimizer_cls = torch.optim.Adamoptimizer_kwargs = {"lr": 1e-4,}# 学习率调度器lr_scheduler_cls = torch.optim.lr_scheduler.ExponentialLRlr_scheduler_kwargs = {"gamma": 0.999}# 对于概率模型,我们使用分位数回归,并将 `loss_fn` 设置为 `None`likelihood = QuantileRegression()loss_fn = Nonereturn {"input_chunk_length": input_chunk_length,  # 回溯窗口"output_chunk_length": output_chunk_length,  # 预测/前瞻窗口"use_reversible_instance_norm": True,"optimizer_kwargs": optimizer_kwargs,"pl_trainer_kwargs": pl_trainer_kwargs,"lr_scheduler_cls": lr_scheduler_cls,"lr_scheduler_kwargs": lr_scheduler_kwargs,"likelihood": likelihood,  # 使用 `likelihood` 进行概率预测"loss_fn": loss_fn,  # 使用 `loss_fn` 进行确定性模型"save_checkpoints": True,  # 检查点以检索最佳模型状态,"force_reset": True,"batch_size": batch_size,"random_state": 42,"add_encoders": {"cyclic": {"future": ["hour", "dayofweek", "month"]}  # 将循环时间轴编码添加为未来协变量},}

模型配置

让我们使用过去一周的小时数据作为回溯窗口( input_chunk_length),并训练一个概率模型直接预测接下来的 24 小时( output_chunk_length)。此外,我们告诉模型使用静态信息。为了保持 notebook 简单,我们将设置 full_training=False。要获得更好的性能,请设置 full_training=True

除此之外,我们使用我们的辅助函数来设置所有常见的模型参数。

input_chunk_length = 7 * 24
output_chunk_length = 24
use_static_covariates = True
full_training = False
# 创建模型
model_tsm = TSMixerModel(**create_params(input_chunk_length,output_chunk_length,full_training=full_training,),use_static_covariates=use_static_covariates,model_name="tsm",
)
model_tide = TiDEModel(**create_params(input_chunk_length,output_chunk_length,full_training=full_training,),use_static_covariates=use_static_covariates,model_name="tide",
)
models = {"TSM": model_tsm,"TiDE": model_tide,
}

模型训练

现在让我们训练所有模型。使用早停时,保存检查点非常重要。这使我们能够继续训练超过最佳模型配置,然后在训练完成后恢复最优权重。

# 训练模型并从其最佳状态/检查点加载模型
for model_name, model in models.items():model.fit(series=train,val_series=val,)# 从检查点加载返回一个新的模型对象,我们将其存储在 models 字典中models[model_name] = model.load_from_checkpoint(model_name=model.model_name, best=True)

对概率模型进行回溯测试

让我们配置预测。对于此示例,我们将:

  • 使用预训练模型在测试集上生成历史预测。每个预测覆盖 24 小时的范围,两个连续预测之间的时间也是 24 小时。这将为我们提供每个变压器 276 个多变量预测来评估模型!

  • 为每个预测点生成500 个随机样本(因为我们已经训练了概率模型)

  • 使用平均分位数损失mql())对一些分位数评估/回溯测试概率历史预测。

我们还将创建一些辅助函数来生成预测、计算回溯测试并可视化预测。

# 配置概率预测
num_samples = 500
forecast_horizon = output_chunk_length# 在这些分位数上计算平均分位数损失
evaluate_quantiles = [0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95]def historical_forecasts(model):"""为每个变压器生成概率历史预测并返回逆变换的结果。每个预测覆盖 24 小时(forecast_horizon)。两个预测之间的时间(stride)也是 24 小时。"""hfc = model.historical_forecasts(series=test,forecast_horizon=forecast_horizon,stride=forecast_horizon,last_points_only=False,retrain=False,num_samples=num_samples,verbose=True,)return scaler.inverse_transform(hfc)def backtest(model, hfc, name):"""使用一组分位数上的平均分位数损失(MQL)评估概率历史预测。"""# 添加指标特定的 kwargsmetric_kwargs = [{"q": q} for q in evaluate_quantiles]metrics = [mql for _ in range(len(evaluate_quantiles))]bt = model.backtest(series=series,historical_forecasts=hfc,last_points_only=False,metric=metrics,metric_kwargs=metric_kwargs,verbose=True,)bt = pd.DataFrame(bt,columns=[f"q_{q}" for q in evaluate_quantiles],index=[f"{trafo}_{name}" for trafo in ["ETTh1", "ETTh2"]],)return btdef generate_plots(n_days, hfcs):"""针对每个模型、变压器和变压器特征,绘制概率预测与真实情况。"""# 将历史预测连接成连续的时间序列# (因为 forecast_horizon=stride,所以可以工作)hfcs_plot = {}for model_name, hfc_model in hfcs.items():hfcs_plot[model_name] = [concatenate(hfc_series[-n_days:], axis=0) for hfc_series in hfc_model]# 记住用于绘制目标序列的开始和结束点hfc_ = hfcs_plot[model_name][0]start, end = hfc_.start_time(), hfc_.end_time()# 对于每个目标列...for col in series[0].columns:fig, axes = plt.subplots(ncols=2, figsize=(12, 6))# ... 并且对于每个变压器...for trafo_idx, trafo in enumerate(series):trafo[col][start:end].plot(label="ground truth", ax=axes[trafo_idx])# ... 绘制每个模型的历史预测for model_name, hfc in hfcs_plot.items():hfc[trafo_idx][col].plot(label=model_name + "_q0.05-q0.95", ax=axes[trafo_idx])axes[trafo_idx].set_title(f"ETTh{trafo_idx + 1}: {col}")plt.show()

好的,现在我们准备评估模型

bts = {}
hfcs = {}
for model_name, model in models.items():print(f"Model: {model_name}")print("Generating historical forecasts..")hfcs[model_name] = historical_forecasts(models[model_name])print("Evaluating historical forecasts..")bts[model_name] = backtest(models[model_name], hfcs[model_name], model_name)
Model: TSM
Generating historical forecasts..
Evaluating historical forecasts..
Model: TiDE
Generating historical forecasts..
Evaluating historical forecasts..

让我们看看它们的表现如何。

注意:当设置 full_training=True 时,这些结果可能会改善/改变

bt_df = pd.concat(bts.values(), axis=0).sort_index()
bt_df
q_0.05q_0.1q_0.2q_0.5q_0.8q_0.9q_0.95
ETTh1_TSM0.5017720.7695451.1361411.5684391.0988470.7218350.442062
ETTh1_TiDE0.5737160.8854521.2986721.6718701.1515010.7275150.446724
ETTh2_TSM0.6591871.0306551.5086281.9329231.3179600.8571470.524620
ETTh2_TiDE0.6272510.9821141.4508931.8971171.3236610.8622390.528638

回溯测试为我们提供了每个变压器和模型的所有变压器特征在所选分位数上的平均分位数损失。值越低越好。q_0.5 与中位数预测与真实值之间的平均绝对误差(MAE)相同。

两个模型似乎表现相当。那么所有分位数的平均表现如何?

bt_df.mean(axis=1)
ETTh1_TSM     0.891234
ETTh1_TiDE    0.965064
ETTh2_TSM     1.118732
ETTh2_TiDE    1.095988
dtype: float64

这里的结果也非常相似。看起来 TSMixer 在 ETTh1 上表现更好,而 TiDEModel 在 ETTh2 上表现更好。

最后但并非最不重要的是,让我们看看测试集中最后 n_days=3 天的预测。

注意:当 full_training=True 时,预测区间预计会变窄。

generate_plots(n_days=3, hfcs=hfcs)

img

img

img

img

img

img

img

结果

在这种情况下,TSMixerTiDEModel 的表现都相当不错。请记住,我们仅对数据进行了部分训练,并且使用了默认模型参数,没有任何超参数调整。

以下是一些进一步提高性能的方法:

  • 设置 full_training=True

  • 执行超参数调整

  • 添加更多协变量(我们仅添加了日历信息的循环编码)

风险提示与免责声明
本文内容基于公开信息研究整理,不构成任何形式的投资建议。历史表现不应作为未来收益保证,市场存在不可预见的波动风险。投资者需结合自身财务状况及风险承受能力独立决策,并自行承担交易结果。作者及发布方不对任何依据本文操作导致的损失承担法律责任。市场有风险,投资须谨慎。

http://www.dtcms.com/a/304036.html

相关文章:

  • WAIC 2025观察:昇腾助力AI融入多元化生活场景
  • sqli-labs通关笔记-第25关GET字符注入(过滤or和and 脚本法)
  • 数据手套五指触觉灵巧手遥操作方案
  • Hyperchain安全与隐私机制详解
  • Windows 下使用 Ollama 调试大模型
  • 故障排除---Operator部署Prometheus无法NodePort访问
  • zoho crm为什么xx是deal的关联对象但是调用函数时报错说不是关联对象
  • 译|生存分析Survival Analysis案例入门讲解(一)
  • 电磁兼容(EMC):整改案例(十三)屏蔽外壳开孔解决433MHz无线通信问题
  • 【硬件-笔试面试题】硬件/电子工程师,笔试面试题-45,(知识点:负反馈的作用,基础理解,干扰和噪声的抑制)
  • React--》实现 PDF 文件的预览操作
  • WisFile(文件整理工具) v1.2.19 免费版
  • 自然语言处理NLP(3)
  • Mac m系列芯片安装node14版本使用nvm + Rosetta 2
  • 【第四章:大模型(LLM)】01.神经网络中的 NLP-(3)文本情感分类实战
  • 网络安全运维面试准备
  • 全自动植树机solidwoeks图纸cad【7张】三维图+设计说明说
  • 第二十二天(数据结构,无头节点的单项链表)
  • 去掉ansible的相关警告信息
  • RK3568下的进程间广播通信:用C语言构建简单的中心服务器
  • 人工智能驱动的自动化革命:重塑工作与社会的未来图景
  • XtestRunner一个比较好用好看的生成测试报告的工具
  • AI Agent推动搜索引擎优化自动化进程
  • python-网络编程
  • 【刷题】东方博宜 1503-排序 容器排序
  • 【数据结构】真题 2016
  • 怎么理解使用MQ解决分布式事务 -- 以kafka为例
  • ABP VNext + GraphQL Federation:跨微服务联合 Schema 分层
  • Java 课程,每天解读一个简单Java之判断101-200之间有多少个素数,并输出所有素数。
  • 如何制定项目计划?核心要点