当前位置: 首页 > news >正文

深度学习打卡第R4周:LSTM-火灾温度预测

  •  🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

    目录

    一、前期准备

    1.1 导入数据

    1.2 数据可视化

    二、构建数据集

    2.1 数据集预处理

    2.2 设置X,y

    2.3 划分数据集

    三、模型训练

    3.1 构建模型

    3.2 定义训练函数

    3.3 定义测试函数

    3.4 正式训练模型

    四、模型评估

    4.1 LOSS图

    4.2 调用模型进行预测

    4.3 R2值评估


一、前期准备

import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
from torch import nn

1.1 导入数据

data = pd.read_csv("woodpine2.csv")
data

1.2 数据可视化

import matplotlib.pyplot as plt
import seaborn as snsplt.rcParams['savefig.dpi'] = 500  # 图片像素
plt.rcParams['figure.dpi'] = 500  # 分辨率fig, ax = plt.subplots(1, 3, constrained_layout=True, figsize=(14, 3))sns.lineplot(data=data["Tem1"], ax=ax[0])
sns.lineplot(data=data["CO 1"], ax=ax[1])
sns.lineplot(data=data["Soot 1"], ax=ax[2])
plt.show()

dataFrame = data.iloc[:,1:]
dataFrame

二、构建数据集

2.1 数据集预处理

from sklearn.preprocessing import MinMaxScalerdataFrame = data.iloc[:,1:].copy()
sc = MinMaxScaler(feature_range=(0, 1))  # 将数据归一化,范围是0到1
for i in ['CO 1', 'Soot 1', 'Tem1']:dataFrame[i] = sc.fit_transform(dataFrame[i].values.reshape(-1, 1))
dataFrame.shape

2.2 设置X,y

width_x = 8
width_y = 1# 取前8个时间段的Tem1、CO 1、Soot 1为X,第9个时间段的Tem1为y。
X = []
y = []in_start = 0
for _, _ in data.iterrows():in_end = in_start + width_xout_end = in_end + width_yif out_end < len(dataFrame):X_ = np.array(dataFrame.iloc[in_start:in_end, ])y_ = np.array(dataFrame.iloc[in_end:out_end, 0])X.append(X_)y.append(y_)in_start += 1X = np.array(X)
y = np.array(y).reshape(-1,1,1)
X.shape, y.shape

2.3 划分数据集

X_train = torch.tensor(np.array(X[:5000]), dtype=torch.float32)
y_train = torch.tensor(np.array(y[:5000]), dtype=torch.float32)
X_test  = torch.tensor(np.array(X[5000:]), dtype=torch.float32)
y_test  = torch.tensor(np.array(y[5000:]), dtype=torch.float32)
X_train.shape, y_train.shape
from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)test_dl  = DataLoader(TensorDataset(X_test, y_test),batch_size=64,shuffle=False)

三、模型训练

3.1 构建模型

class model_lstm(nn.Module):def __init__(self):super(model_lstm, self).__init__()self.lstm0 = nn.LSTM(input_size=3, hidden_size=320,num_layers=1, batch_first=True)self.lstm1 = nn.LSTM(input_size=320, hidden_size=320,num_layers=1, batch_first=True)self.fc0  = nn.Linear(320, 1)def forward(self, x):out, hidden1 = self.lstm0(x)out, _ = self.lstm1(out, hidden1)out  = self.fc0(out)return out[:, -1:, :]  #取1个预测值,否则经过lstm会得到8*1个预测
model = model_lstm()
model

3.2 定义训练函数

# 训练循环
import copy
def train(train_dl, model, loss_fn, opt, lr_scheduler=None):size      = len(train_dl.dataset)num_batches = len(train_dl)train_loss = 0  # 初始化训练损失和正确率for x, y in train_dl:x, y = x.to(device), y.to(device)# 计算预测误差pred = model(x)         # 网络输出loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距# 反向传播opt.zero_grad() # grad属性归零loss.backward() # 反向传播opt.step()      # 每一步自动更新# 记录losstrain_loss += loss.item()if lr_scheduler is not None:lr_scheduler.step()print("learning rate = {:.5f}".format(opt.param_groups[0]['lr']), end=" ")train_loss /= num_batchesreturn train_loss

3.3 定义测试函数

def test (dataloader, model, loss_fn):size      = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)        # 批次数目test_loss  = 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)# 计算lossy_pred = model(x)loss      = loss_fn(y_pred, y)test_loss += loss.item()test_loss /= num_batchesreturn test_loss

3.4 正式训练模型

#设置GPU训练
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
#训练模型
model = model_lstm()
model = model.to(device)
loss_fn    = nn.MSELoss() # 创建损失函数
learn_rate = 1e-1  # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate,weight_decay=1e-4)
epochs     = 50
train_loss = []
test_loss  = []
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,epochs, last_epoch=-1)for epoch in range(epochs):model.train()epoch_train_loss = train(train_dl, model, loss_fn, opt, lr_scheduler)model.eval()epoch_test_loss = test(test_dl, model, loss_fn)train_loss.append(epoch_train_loss)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_loss:{:.5f}, Test_loss:{:.5f}')print(template.format(epoch+1, epoch_train_loss, epoch_test_loss))print("="*20, 'Done', "="*20)

四、模型评估

4.1 LOSS图

import matplotlib.pyplot as plt
from datetime import datetime
current_time = datetime.now()  # 获取当前时间plt.figure(figsize=(5, 3), dpi=120)plt.plot(train_loss, label='LSTM Training Loss')
plt.plot(test_loss, label='LSTM Validation Loss')plt.title('Training and Validation Loss')
plt.xlabel(current_time)  # 打卡请带上时间戳,否则代码截图无效
plt.legend()
plt.show()

4.2 调用模型进行预测

predicted_y_lstm = sc.inverse_transform(model(X_test).detach().numpy().reshape(-1,1))
y_test_1         = sc.inverse_transform(y_test.reshape(-1,1))
y_test_one       = [i[0] for i in y_test_1]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]plt.figure(figsize=(5, 3), dpi=120)
# 画出真实数据和预测数据的对比曲线
plt.plot(y_test_one[:2000], color='red', label='real_temp')
plt.plot(predicted_y_lstm_one[:2000], color='blue', label='prediction')plt.title('Title')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.show()

4.3 R2值评估

from sklearn import metrics
"""
RMSE :均方根误差 -----> 对均方误差开方
R2  :决定系数,可以简单理解为反映模型拟合优度的重要的统计量
"""
RMSE_lstm = metrics.mean_squared_error(predicted_y_lstm_one, y_test_1)**0.5
R2_lstm   = metrics.r2_score(predicted_y_lstm_one, y_test_1)print('均方根误差: %.5f' % RMSE_lstm)
print('R2: %.5f' % R2_lstm)

http://www.dtcms.com/a/609782.html

相关文章:

  • 最好的营销策划公司做seo网站优化价格
  • 通过Rust高性能异步网络服务器的实现看Rust语言的核心优势
  • 第36节:AI集成与3D场景中的智能NPC
  • 一个基于 LayUI + .NET 开源、轻量的医院住院管理系统
  • StarRocks 4.0:让 Apache Iceberg 数据真正 Query-Ready
  • 网站建设 自己的服务器爬虫python入门
  • android抽屉DrawerLayout在2025的沉浸式兼容
  • 美颜SDK性能优化实战:GPU加速与AI人脸美型的融合开发
  • AndroidStudio历史版本下载
  • Mac抹除重装卡在激活锁?两步快速解锁
  • Java语言是编译型还是解释型| 探究Java的运行机制与性能优化
  • 网站发语音功能如何做广州比较好的网站建设公司
  • 公司网站域名更改怎么做建设行业协会网站发展的建议
  • 【ZeroRange WebRTC】Kinesis Video Streams WebRTC Data Plane WebSocket API 深度解析
  • Docker核心概念、常用命令与实战指南
  • 交换机安全基线整改方式-华为S5700系列
  • Django 接口文档生成:Swagger 与 ReDoc 全面说明
  • Docker K8s VM 简介
  • FPGA教程系列-Vivado中读取ROM中数据
  • 网站怎么添加模块鹿寨建设局网站
  • 响应式外贸网站案例国外ps网站
  • springcloud feign远程调用请求参数对象变成linkhashmap处理
  • “耐达讯自动化Profibus总线光端机在化工变频泵控制系统中的应用与价值解析”
  • centos7.2安装cacti1.2.27
  • 将 vue3 项目打包后部署在 springboot 项目运行
  • 福州短视频seo网站建筑网站首页大图
  • 阿根廷网站后缀毕业设计网站成品
  • 性能相关指标
  • 数据结构--6:优先级队列(堆)
  • ESP32 Wsl2 环境搭建