当前位置: 首页 > news >正文

长短期记忆神经网络(LSTM)基础学习与实例:预测序列的未来

目录

1. 前言

2. LSTM的基本原理

2.1 LSTM基本结构

2.2 LSTM的计算过程

3. LSTM实例:预测序列的未来

3.1 数据准备

3.2 模型构建

3.3 模型训练

3.4 模型预测

3.5 完整程序预测序列的未来 

4. 总结


1. 前言

在深度学习领域,循环神经网络(RNN)是处理序列数据的重要工具。然而,传统的RNN在处理长序列时常常会遇到梯度消失或梯度爆炸的问题,导致模型无法有效学习长期依赖关系。为了解决这一问题,长短期记忆神经网络(LSTM)应运而生。LSTM通过引入特殊的结构设计,能够有效地捕获序列数据中的长期依赖关系,因此在自然语言处理、时间序列预测等领域取得了显著的成果。

LSTM在许多序列数据处理任务中表现出色,包括但不限于:

  • 自然语言处理:文本生成、机器翻译、情感分析等。

  • 时间序列预测:股票价格预测、天气预报等。

  • 语音识别:将语音信号转换为文字。

  • 视频分析:动作识别、场景理解等。

如果没有RNN的基础,可以去看这篇博客,

《循环神经网络(RNN)基础入门与实践学习:电影评论情感分类任务》

2. LSTM的基本原理

2.1 LSTM基本结构

传统的RNN在处理长序列时,梯度会随着序列长度的增加而逐渐消失或爆炸。这种现象使得RNN难以学习到序列中的长期依赖关系。例如,在处理一段较长的文本时,RNN可能无法将开头的信息有效传递到结尾,导致模型性能受限。

LSTM通过引入门控机制(gate mechanism)解决了这一问题。门控机制可以控制信息的流动,决定哪些信息应该被保留,哪些信息应该被遗忘,从而有效地捕获长期依赖关系。

LSTM的核心结构包括三个门:

  1. 遗忘门(Forget Gate):决定哪些信息应该被遗忘。

  2. 输入门(Input Gate):决定哪些新信息应该被存储到单元状态中。

  3. 输出门(Output Gate):决定哪些信息应该被输出。

此外,LSTM还有一个单元状态(Cell State),用于在时间步之间传递和存储信息。

在实际中,其整体结构如下:

其中蓝色小球里面存放的就是门控结构的神经元 。

2.2 LSTM的计算过程

这里讲的是上图中的蓝色小球内部结构。

在每个时间步,LSTM根据输入数据和前一时刻的隐藏状态,计算三个门的值,并更新单元状态和隐藏状态。具体计算过程如下:

  1. 遗忘门

    其中,ft​ 是遗忘门的输出,σ 是sigmoid激活函数,Wf​ 和 bf​ 是权重和偏置,ht−1​ 是前一时刻的隐藏状态,xt​ 是当前时刻的输入。

  2. 输入门

    其中,it​ 是输入门的输出,C~t​ 是候选单元状态。

  3. 单元状态更新

    其中,Ct​ 是更新后的单元状态。

  4. 输出门

    其中,ot​ 是输出门的输出,ht​ 是当前时刻的隐藏状态。

为了更方便理解,其结构图如下:

从左往右依次为遗忘门,输入门,输出门。 

3. LSTM实例:预测序列的未来

3.1 数据准备

我们以一个简单的时间序列预测任务为例,预测未来某个时间点的值。首先生成一些模拟数据:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense

# 生成模拟数据
def generate_data(sequence_length=1000):
    x = np.linspace(0, 50, sequence_length)
    y = np.sin(x) + 0.1 * np.random.randn(sequence_length)
    return y

# 准备训练数据
def prepare_data(data, window_size=50):
    X, y = [], []
    for i in range(len(data) - window_size):
        X.append(data[i:i+window_size])
        y.append(data[i+window_size])
    return np.array(X), np.array(y)

# 生成数据
data = generate_data()
window_size = 50
X, y = prepare_data(data, window_size)

# 划分训练集和测试集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

# 数据形状调整
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))

3.2 模型构建

使用Keras构建一个简单的LSTM模型:

# 构建LSTM模型
model = Sequential()
model.add(LSTM(50, return_sequences=True, input_shape=(window_size, 1)))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')

# 打印模型结构
model.summary()
  • window_size=50:每个样本包含 50 个时间步。

  • 输入数据的形状为 (样本数, 50, 1)

  • 第一个 LSTM 层的输出形状为 (样本数, 50, 50),因为:

    • 每个时间步输出 50 个特征(由 50 个神经元生成)。

    • return_sequences=True,所以输出了所有 50 个时间步的特征。

如果后续还有一个 LSTM 层,则第二个 LSTM 层的输入形状为 (50, 50)(每个时间步有 50 个特征)。

3.3 模型训练

训练模型并记录训练过程:

# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)

# 绘制训练损失和验证损失
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

3.4 模型预测

使用训练好的模型进行预测,并可视化结果:

# 预测
predictions = model.predict(X_test)

# 绘制真实值和预测值
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Values')
plt.plot(predictions, label='Predictions')
plt.title('True Values vs. Predictions')
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.legend()
plt.show()

3.5 完整程序预测序列的未来 

 完整程序如下方便调试:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# 生成模拟数据
def generate_data(sequence_length=1000):
    x = np.linspace(0, 50, sequence_length)
    y = np.sin(x) + 0.1 * np.random.randn(sequence_length)
    return y

# 准备训练数据
def prepare_data(data, window_size=50):
    X, y = [], []
    for i in range(len(data) - window_size):
        X.append(data[i:i+window_size])
        y.append(data[i+window_size])
    return np.array(X), np.array(y)

# 生成数据
data = generate_data()
window_size = 50
X, y = prepare_data(data, window_size)

# 划分训练集和测试集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

# 数据形状调整
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))

# 构建LSTM模型
model = Sequential()
model.add(LSTM(60, return_sequences=True, input_shape=(window_size, 1)))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')

# 打印模型结构
model.summary()

# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)

# 绘制训练损失和验证损失
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 预测
predictions = model.predict(X_test)

# 绘制真实值和预测值
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Values')
plt.plot(predictions, label='Predictions')
plt.title('True Values vs. Predictions')
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.legend()
plt.show()

4. 总结

长短期记忆神经网络(LSTM)是一种强大的序列建模工具,能够有效地捕获长期依赖关系。通过引入遗忘门、输入门和输出门,LSTM解决了传统RNN在处理长序列时的梯度消失问题。在本文中,我们详细介绍了LSTM的基本原理和结构,并通过一个时间序列预测的实例展示了如何使用Keras实现LSTM模型。

尽管LSTM在许多任务中表现出色,但它也有一些局限性,例如计算复杂度较高、训练时间较长等。随着深度学习技术的发展,许多改进的变体(如GRU、双向LSTM等)也逐渐被提出。在实际应用中,选择合适的模型需要根据具体任务和数据特点进行权衡。我是橙色小博,关注我,一起在人工智能领域学习进步。

http://www.dtcms.com/a/108030.html

相关文章:

  • 外卖平台问题
  • 未来幻想世界
  • JAVA学习小计之IO流01-字节流篇
  • Axure 使用笔记
  • leetcode:3083. 字符串及其反转中是否存在同一子字符串(python3解法)
  • 算法设计与分析之“分治法”
  • Oracle常用高可用方案(10)——RAC
  • MFC BCGControlBar
  • 光谱相机的光谱数据采集原理
  • Python设计模式:代理模式
  • 看行业DeepSeekR1模型如何构建及减少推理大模型过度思考
  • IntelliJ IDEA全栈Git指南:从零构建到高效协作开发
  • 洛谷题单3-P1009 [NOIP 1998 普及组] 阶乘之和-python-流程图重构
  • vue中的 拖拽
  • @ComponentScan注解详解:Spring组件扫描的核心机制
  • 【力扣hot100题】(037)翻转二叉树
  • 每日一题---买卖股票的最好时机(一)、(二)
  • 【每日算法】Day 15-1:哈希表与布隆过滤器——海量数据处理与高效检索的核心技术(C++实现)
  • ollama本地部署大模型(命令行)
  • Eclipse IDE
  • 基本元素定位(findElement方法)
  • 【嵌入式Linux】U-Boot源码分析
  • JMeter接口自动化发包与示例
  • Windows连接服务器Ubuntu_MobaXterm
  • 【Mysql】基础(函数,约束,多表查询,事务)
  • PHP语言基础
  • 深入解析C++类:面向对象编程的核心基石
  • 前端css+html面试题
  • 面向对象分析与设计的多过程多层级实现
  • Generic Mapping Tools(GMT):开源的地球、海洋和行星科学的工具箱、Python与matlab包