当前位置: 首页 > news >正文

LSTM 与随机森林的对比

LSTM 与随机森林的对比

特性

LSTM

随机森林

适用场景

适合复杂的时间序列数据,能捕捉长时依赖

适合短期预测,适用于较小数据集

计算成本

训练成本较高,需要GPU加速

计算成本较低,易于并行化

可解释性

较低,难以理解内部状态

较高,可解释性强

对噪声的鲁棒性

对噪声敏感,可能需要更大数据量

对噪声较鲁棒,适用于非平稳数据

总结3点

  1. LSTM 适用于长时间依赖的序列预测,适合非线性、时序特征复杂的问题。

  2. 随机森林适用于短期时间序列预测,对于非时间依赖特征有较好的处理能力,计算效率高且易于解释。

  3. 实际应用时可以结合两者,如先用随机森林提取特征,再输入 LSTM 进行预测,提升预测精度。

完整案例

这个案例涉及 LSTM 和随机森林在时间序列预测中的完整流程。

包括:

  • 数据生成(模拟时间序列数据)

  • 数据预处理(特征工程、数据分割)

  • 模型训练(LSTM 和随机森林)

  • 可视化分析

  • 超参数优化(调优策略和优化点)

  • 最终结论(对比 LSTM 和随机森林的预测效果)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# 1. 生成模拟时间序列数据
def generate_data(n=500):
    np.random.seed(42)
    time = np.arange(n)
    trend = 0.05 * time
    seasonality = 10 * np.sin(time * (2 * np.pi / 50))
    noise = np.random.normal(0, 2, n)
    data = trend + seasonality + noise
    return pd.DataFrame({'time': time, 'value': data})

data = generate_data()

# 2. 数据预处理
scaler = MinMaxScaler()
data['scaled_value'] = scaler.fit_transform(data[['value']])

# 3. 创建特征和标签
def create_features(data, window=10):
    X, y = [], []
    for i in range(len(data) - window):
        X.append(data['scaled_value'].iloc[i:i+window].values)
        y.append(data['scaled_value'].iloc[i+window])
    return np.array(X), np.array(y)

X, y = create_features(data)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

# 4. 训练随机森林模型
rf = RandomForestRegressor(n_estimators=100)
rf.fit(X_train, y_train)
rf_preds = rf.predict(X_test)

# 5. 训练 LSTM 模型(使用 PyTorch)
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        return self.fc(lstm_out[:, -1, :])

input_size = 1
hidden_size = 50
num_layers = 2
output_size = 1

model = LSTMModel(input_size, hidden_size, num_layers, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

X_train_torch = torch.tensor(X_train, dtype=torch.float32).unsqueeze(-1)
y_train_torch = torch.tensor(y_train, dtype=torch.float32).unsqueeze(-1)
X_test_torch = torch.tensor(X_test, dtype=torch.float32).unsqueeze(-1)

epochs = 20
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train_torch)
    loss = criterion(outputs, y_train_torch)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

model.eval()
lstm_preds = model(X_test_torch).detach().numpy()

# 6. 可视化
plt.figure(figsize=(12, 8))
sns.set_style("darkgrid")

plt.subplot(2, 2, 1)
plt.plot(data['time'], data['value'], color='blue', label='Original Data')
plt.title("Original Time Series Data")
plt.legend()

plt.subplot(2, 2, 2)
plt.plot(y_test, label='True', color='black')
plt.plot(rf_preds, label='RF Prediction', color='red')
plt.title("Random Forest Prediction")
plt.legend()

plt.subplot(2, 2, 3)
plt.plot(y_test, label='True', color='black')
plt.plot(lstm_preds, label='LSTM Prediction', color='green')
plt.title("LSTM Prediction")
plt.legend()

plt.subplot(2, 2, 4)
plt.plot(y_test, label='True', color='black')
plt.plot(rf_preds, label='RF', color='red')
plt.plot(lstm_preds, label='LSTM', color='green')
plt.title("Model Comparison")
plt.legend()

plt.tight_layout()
plt.show()

图片

  1. 随机森林调优:增加 n_estimators=200,限制 max_depth=10 以避免过拟合。

  2. LSTM 模型优化

    • hidden_size=64 提高表达能力。

    • num_layers=3 使模型更深。

    • dropout=0.2 以减少过拟合。

    • epochs=50 确保充分训练,并在每 10 轮打印损失值。

  3. 进一步优化

    • 适当调整 batch_size,提高训练效率。

    • 尝试 GRU 以减少计算开销。

相关文章:

  • stream流常用方法
  • uniapp 滚动尺
  • 【湖南-益阳】《益阳市市本级政府投资信息化项目预算编制与财政评审工作指南》益财评〔2024〕346号-省市费用标准解读系列40
  • 远程计算机无conda情况下配置python虚拟环境
  • Go入门之函数
  • Redis初识
  • 微软宣布 Windows 11 将不再免费升级:升级需趁早
  • Python入门笔记3
  • Mybatis-Plus
  • 数据结构:栈和队列
  • 灵办AI助手Chrome插件全面评测:PC Web端的智能办公利器
  • 学习总结2.14
  • 科普:Docker run的相关事项
  • Redis缓存雪崩、击穿、穿透
  • 第一章 Java面向对象进阶
  • 利用AFE+MCU构建电池管理系统(BMS)
  • 设计模式相关知识点
  • 驱动开发、移植
  • 2025最新智能优化算法:改进型雪雁算法(Improved Snow Geese Algorithm, ISGA)求解23个经典函数测试集,MATLAB
  • MYSQL总结(1)
  • 习近平同瑞典国王卡尔十六世·古斯塔夫就中瑞建交75周年互致贺电
  • 两部门发布外汇领域行刑反向衔接案例,织密金融安全“防护网”
  • 中国驻美国大使馆发言人就中美经贸高层会谈答记者问
  • 债券市场“科技板”来了:哪些机构能尝鲜,重点支持哪些领域
  • 牛市早报|金融政策支持稳市场稳预期发布会将举行,商务部:中方决定同意与美方进行接触
  • 经济日报:落实落细更加积极的财政政策