当前位置: 首页 > news >正文

时间序列模型(1):LSTNet

LSTNet

LSTNet(Long- and Short-term Time-series Network)是一种专门用于时间序列预测的深度学习模型,结合了卷积神经网络(CNN)和循环神经网络(RNN)的优势,能够同时捕捉时间序列中的长期和短期依赖关系。

1. 模型背景

时间序列数据通常包含两种模式:

短期依赖:如季节性、周期性变化。

长期依赖:如趋势、突变点等。

传统模型(如ARIMA、RNN)往往难以同时捕捉这两种模式,而LSTNet通过多组件结构解决了这一问题。

2. 模型结构

LSTNet的核心结构包括以下几个模块:

在这里插入图片描述

(1) 卷积层(CNN)

作用:提取时间序列中的局部模式(短期依赖)。

实现:使用一维卷积核在时间维度上滑动,捕捉相邻时间点之间的关系。

优点:能够高效提取短期特征,适合处理高维时间序列。

(2) 循环层(RNN/GRU)

作用:捕捉时间序列中的长期依赖关系。

实现:通常使用GRU(门控循环单元)或LSTM(长短期记忆网络),以减少梯度消失问题。

优点:能够建模时间序列中的复杂时间依赖。

缺点:因为存在梯度消失和梯度爆炸,导致对长序列的处理能力有限。

(3) Skip-RNN

作用:专门捕捉时间序列中的周期性模式(如季节性)。

实现:跳跃连接(skip connection)将当前时间步与历史时间步直接关联,通过指定跳越时间步来周期性提取特征,适合处理周期性数据。

优点:显式建模周期性,提升对季节性数据的预测能力。

(4) 自回归组件(AR)

作用:捕捉线性依赖关系,增强模型的稳定性。

实现:使用传统的自回归模型(如ARIMA中的AR部分)对残差进行建模。

优点:防止深度学习模型过度依赖非线性特征,提升鲁棒性。

(5) 全连接层

作用:将CNN和RNN的输出进行整合,生成最终预测结果。

实现:通过全连接层将特征映射到目标输出维度。

3. 模型特点

多尺度建模:通过CNN和RNN的结合,同时捕捉短期和长期依赖。

周期性建模:Skip-RNN显式建模周期性模式,适合季节性数据。

4. 源码参数分析

import torch
import torch.nn as nn
import torch.nn.functional as F

# traffic数据集
class Model(nn.Module):
    def __init__(self, args, data):
        super(Model, self).__init__()
        self.use_cuda = args.cuda
        self.P = args.window; # 时间序列长度 default: 24 * 7
        self.m = data.m # 时间序列输入维度 default: 862
        # 输出维度
        self.hidR = args.hidRNN; # default: 100
        self.hidC = args.hidCNN; # default: 100
        self.hidS = args.hidSkip; # default: 10
        self.Ck = args.CNN_kernel; #kernel size, default: 6
        self.skip = args.skip; # 周期长度, default: 24
        self.pt = int((self.P - self.Ck)/self.skip) # # window在kernel作用下,以skip为周期的数据数量。周期数目。
        self.hw = args.highway_window  # highway通道的输出节点数目
        self.conv1 = nn.Conv2d(1, self.hidC, kernel_size = (self.Ck, self.m));
        self.GRU1 = nn.GRU(self.hidC, self.hidR);
        self.dropout = nn.Dropout(p = args.dropout);
        if (self.skip > 0):
            self.GRUskip = nn.GRU(self.hidC, self.hidS);
            self.linear1 = nn.Linear(self.hidR + self.skip * self.hidS, self.m);
        else:
            self.linear1 = nn.Linear(self.hidR, self.m);
        if (self.hw > 0):
            self.highway = nn.Linear(self.hw, 1);
        self.output = None;
        if (args.output_fun == 'sigmoid'):
            self.output = F.sigmoid;
        if (args.output_fun == 'tanh'):
            self.output = F.tanh;
 
    def forward(self, x):
        batch_size = x.size(0);
        #CNN
        c = x.view(-1, 1, self.P, self.m);
        c = F.relu(self.conv1(c)); # (32, 1, 168, 862) -> (32, 100, 163, 1)
        c = self.dropout(c);
        c = torch.squeeze(c, 3); # (32, 100, 163), 提取连续时间步的局部特征
        # RNN 
        r = c.permute(2, 0, 1).contiguous(); # (163, 32, 100)
        _, r = self.GRU1(r);
        r = self.dropout(torch.squeeze(r,0)); # (32, 100)
        
        #skip-rnn
        
        if (self.skip > 0):
            s = c[:,:, int(-self.pt * self.skip):].contiguous(); # (32, 100, 144)
            s = s.view(batch_size, self.hidC, self.pt, self.skip); # (32, 100, 6, 24),将144个时间步的数据,按照6个周期,每个周期24个时间步划分,捕捉周期性规律
            s = s.permute(2,0,3,1).contiguous(); # (6, 32, 24, 100)
            s = s.view(self.pt, batch_size * self.skip, self.hidC); # (6, 768, 100)
            _, s = self.GRUskip(s);
            s = s.view(batch_size, self.skip * self.hidS);
            s = self.dropout(s); # (32, 240)
            r = torch.cat((r,s),1); # (32, 340)
        
        res = self.linear1(r);
        
        # highway,模型线性AR
        if (self.hw > 0):
            z = x[:, -self.hw:, :];
            z = z.permute(0,2,1).contiguous().view(-1, self.hw);
            z = self.highway(z);
            z = z.view(-1,self.m);
            res = res + z;
            
        if (self.output):
            res = self.output(res); # (32, 862)
        return res;

相关文章:

  • 解决ubuntu(jetpack)系统下系统盘存储不够的
  • MongoDB备份与还原
  • 2025年第十届数维杯大学生数学建模挑战赛参赛规则
  • Windows根据文件名批量在文件夹里查找文件并复制出来,用WPF实现的详细步骤
  • 29.代码随想录算法训练营第二十九天|134. 加油站,135. 分发糖果,860. 柠檬水找零,406. 根据身高重建队列
  • [rust] rust学习
  • 【C语言系列】字符函数和字符串函数
  • QT:串口上位机
  • 电脑神器,轻松超越系统自带!
  • 【免费】2006-2020年各省单位GDP能耗增速数据
  • 每日学习之一万个为什么
  • MySQL的 where 1=1会不会影响性能?
  • Stable Diffusion/DALL-E 3图像生成优化策略
  • Linux:自动化构建-make/Makefile
  • 软件开发项目有哪些风险
  • Redis Sentinel (哨兵模式)深度解析:构建高可用分布式缓存系统的核心机制
  • 【大模型学习】第十七章 预训练技术综述
  • [翱捷]功能机 Wifi
  • Pygame实现射击鸭子游戏3-2
  • 根据 GPU 型号安装指定 CUDA 版本的详细步骤(附有CUDA版本对应torch版本的表格)
  • 王毅将主持召开第三次中国—太平洋岛国外长会
  • 钟南山谈新冠阳性率升高:可防可治不用慌,高危人群应重点关注
  • 安徽凤阳通报鼓楼瓦片脱落:2023年曾维修,已成立调查组
  • 讲述“外国货币上的中国故事”,《世界钱币上的中国印记》主题书刊出版发布
  • 证监会副主席李明:支持符合条件的外资机构申请新业务、设立新产品
  • 上海青少年书法学习园开园:少年以巨笔书写《祖国万岁》