当前位置: 首页 > news >正文

LSTM 单变量时序预测—pytorch

工具包导入+数据读取
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
df = pd.read_csv('/opt/cyc/week_task/xiaoliangyuce/data_qingxi/0804/data_0803_300_new.csv')
df1=df[df['uuid']=='B0F59QC63ZUS']
df2=df1[['sale_list_time','sale_list_rank']]
数据集划分+数据集归一化+滑动窗口构造
#划分数据集
train_data,test_data=df2['sale_list_rank'][:-20],df2['sale_list_rank'][-20:]
#训练集归一化
scaler=MinMaxScaler()
train_data_scale=scaler.fit_transform(train_data.values.reshape(-1,1))
test_data_scale=scaler.fit_transform(test_data.values.reshape(-1,1))#构造滑动窗口
def slid_window_data(data,window_size):X,Y=[],[]for i in range(len(data)-window_size):X.append(data[i:i+window_size])Y.append(data[i+window_size:i+window_size+1])return np.array(X),np.array(Y)
#lstm需要的形状:(样本数,时序长度,特征数)"""[1,2,3,4,5]:滑动窗口为3[1,2,3][2,3,4][4][5]最后得到两条数据,形状[2,3,1]"""
X_train,Y_train=slid_window_data(data=train_data_scale,window_size=20)
lstm模型
class LSTM_MODEL(nn.Module):def __init__(self, input_size=1, hidden_size=100,num_layers=1):super(LSTM_MODEL,self).__init__()self.hidden_size=hidden_sizeself.lstm=nn.LSTM(input_size,hidden_size,num_layers, batch_first=True)self.fc=nn.Linear(hidden_size,1)def forward(self,x):out,_=self.lstm(x)batch_size,seq_len,hidden_size=out.shape#seq_len:序列长度,在NLP中就是句子长度,一般都会用pad_sequence补齐长度#batch:每次喂给网络的数据条数,在NLP中就是一次喂给网络多少个句子#input_size:特征维度,和前面定义网络结构的input_size一致。x=self.fc(out)x = x[:,-1,:]return x
模型训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
Y_train = torch.tensor(Y_train, dtype=torch.float32).to(device)
model=LSTM_MODEL()
model = model.to(device) 
criterion=torch.nn.MSELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
epoch=100
for epoch in range(epoch):model.train()output=model(X_train)loss=criterion(output,Y_train)loss.backward()optimizer.step()print(f"Epoch [{epoch+1}],LOSS:{loss.item():.4f}")
预测未来值
#制作未来预测数据输入
fur_len=20
train_data_normalized = torch.FloatTensor(train_data_scale).view(-1)
test_input=train_data_normalized[-fur_len:].tolist()
model.eval()
with torch.no_grad():model.hidden_cell = (torch.zeros(1, 1, model.hidden_size).to(device),torch.zeros(1, 1, model.hidden_size).to(device))for i in range(fur_len):#print(test_input[-fur_len:])seq = torch.FloatTensor(test_input[-fur_len:])seq = seq.to(device).unsqueeze(0).unsqueeze(2)  # [1, time_step, 1]test_input.append(model(seq).item())
test_input[fur_len:]
可视化对比
actual_predictions = scaler.inverse_transform(np.array(test_input[fur_len:] ).reshape(-1, 1))
plt.plot(list(range(len(test_data))),test_data,'ro-' )
plt.plot(list(range(len(actual_predictions))),actual_predictions,'bo-' )
plt.legend(["true","pred"])
plt.show()

http://www.dtcms.com/a/319782.html

相关文章:

  • vscode+latex本地英文期刊环境配置
  • VScode使用jupyter notebook,配置内核报错没有torch解决
  • 如何委托第三方检测机构做软件测试?
  • 鸿蒙 - 分享功能
  • 直播预告|鸿蒙生态下的 Flutter 开发实战
  • 非化学冷却塔水处理解决方案:绿色工业时代的革新引擎
  • Elasticsearch 文档分词器
  • 神经网络入门指南:从零理解 PyTorch 的核心思想
  • 2025 五大商旅平台管控力解析:合规要求下的商旅管理新范式
  • Flutter 布局控件使用详解
  • 【java基础|第十六篇】面向对象(六)——抽象和接口
  • Java-JVM探析
  • 参考平面与返回电流
  • BMS保护板测试仪:电池安全管理的“质检卫士”|深圳鑫达能
  • Java爬虫性能优化:多线程抓取JSP动态数据实践
  • 键盘+系统+软件等快捷键大全
  • RK3568笔记九十八:使用Qt实现RTMP拉流显示
  • FluentUI-main的详解
  • MyBatis联合查询
  • windows有一个企业微信安装包,脚本执行并安装到d盘。
  • 我的世界Java版1.21.4的Fabric模组开发教程(十七)自定义维度
  • PCL提取平面上的圆形凸台特征
  • WindowsLinux系统 安装 CUDA 和 cuDNN
  • 从库存一盘货到全域智能铺货:巨益科技全渠道平台助力品牌业财一体化升级
  • 电子基石:硬件工程师的器件手册 (九) - DC-DC拓扑:电能转换的魔术师
  • 线上业务突然流量掉 0 ?一次 DNS 污染排查与自救实录
  • Qt中类提升后不显示问题
  • 纷享销客前端实习一面
  • 数据结构(五):顺序循环队列与哈希表
  • 纪念《信号与系统》拉普拉斯变换、Z变换之前内容学完