当前位置: 首页 > news >正文

第14篇:循环神经网络(RNN)与LSTM:序列建模的利器

摘要
本文系统讲解循环神经网络(RNN)的核心思想、结构与局限,深入解析LSTM(长短期记忆)与GRU的门控机制。结合TensorFlow/Keras,实现文本生成(莎士比亚风格)与时间序列预测(股价)两大实战项目。帮助学习者掌握处理序列数据的“时间机器”。


一、为什么需要RNN?

传统神经网络(如MLP、CNN)假设输入输出独立同分布,但许多任务涉及序列数据

  • 文本:单词有前后顺序
  • 语音:声音信号是时间序列
  • 时间序列:股价、天气随时间变化
  • 机器翻译:输入输出均为序列

RNN通过“循环”结构,使网络具有“记忆”能力,能处理变长序列。


二、RNN基础

2.1 核心思想

RNN在时间步(timestep)上共享参数,并将前一时刻的隐藏状态传递给下一时刻。

2.2 结构与公式

  • 隐藏状态(Hidden State):
    hₜ = tanh(Wₕₕ hₜ₋₁ + Wₓₕ xₜ + bₕ)
    
  • 输出
    yₜ = Wₕᵧ hₜ + bᵧ
    

其中:

  • xₜ:t时刻输入
  • hₜ:t时刻隐藏状态(“记忆”)
  • yₜ:t时刻输出
  • W:权重矩阵

2.3 RNN的“展开”视图

x₁ → h₁ → y₁↑
x₂ → h₂ → y₂↑
x₃ → h₃ → y₃
...

✅ 所有时间步共享同一组权重(Wₕₕ, Wₓₕ, Wₕᵧ)。


三、RNN的局限:梯度消失与梯度爆炸

3.1 问题描述

在反向传播时,梯度通过时间(Backpropagation Through Time, BPTT)传递:

∂L/∂W ∝ ∂hₜ/∂hₜ₋₁ ∂hₜ₋₁/∂hₜ₋₂ ... ∂h₂/∂h₁
  • ||∂hₖ/∂hₖ₋₁|| < 1 → 梯度指数衰减(梯度消失
  • ||∂hₖ/∂hₖ₋₁|| > 1 → 梯度指数爆炸(梯度爆炸

🚫 梯度消失导致RNN难以学习长期依赖(如句子开头与结尾的关系)。


四、LSTM:长短期记忆网络

4.1 核心思想

LSTM通过门控机制(Gating)控制信息的遗忘更新输出,有效缓解梯度消失。

4.2 结构与公式

LSTM单元包含:

  • 细胞状态(Cell State, cₜ):长期记忆通道
  • 隐藏状态(Hidden State, hₜ):短期记忆/输出
三大门:
  1. 遗忘门(Forget Gate):
    fₜ = σ(W_f · [hₜ₋₁, xₜ] + b_f)
    
  2. 输入门(Input Gate):
    iₜ = σ(W_i · [hₜ₋₁, xₜ] + b_i)
    c̃ₜ = tanh(W_c · [hₜ₋₁, xₜ] + b_c)
    
  3. 输出门(Output Gate):
    oₜ = σ(W_o · [hₜ₋₁, xₜ] + b_o)
    
状态更新:
cₜ = fₜ * cₜ₋₁ + iₜ * c̃ₜ
hₜ = oₜ * tanh(cₜ)

cₜ的梯度可近乎恒定传递,解决长期依赖问题。


五、GRU:门控循环单元

5.1 简化版LSTM

GRU将遗忘门和输入门合并为更新门(Update Gate),并减少一个门。

5.2 公式

  • 更新门
    zₜ = σ(W_z · [hₜ₋₁, xₜ])
    
  • 重置门
    rₜ = σ(W_r · [hₜ₋₁, xₜ])
    
  • 候选隐藏状态
    h̃ₜ = tanh(W · [rₜ * hₜ₋₁, xₜ])
    
  • 最终隐藏状态
    hₜ = (1 - zₜ) * hₜ₋₁ + zₜ * h̃ₜ
    

✅ GRU参数更少,训练更快,性能常接近LSTM。


六、实战1:文本生成(莎士比亚风格)

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import requests# 下载莎士比亚文本
url = "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt"
text = requests.get(url).text# 创建字符映射
vocab = sorted(set(text))
char2idx = {c: i for i, c in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])# 创建序列数据集
seq_length = 100
examples_per_epoch = len(text) // (seq_length + 1)char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)def split_input_target(chunk):input_text = chunk[:-1]target_text = chunk[1:]return input_text, target_textdataset = sequences.map(split_input_target)# 批处理与打乱
BATCH_SIZE = 64
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)# 构建LSTM模型
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024model = models.Sequential([layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, seq_length]),layers.LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),layers.Dense(vocab_size)
])# 损失函数
def loss(labels, logits):return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)model.compile(optimizer='adam', loss=loss)# 训练
EPOCHS = 3
history = model.fit(dataset, epochs=EPOCHS)# 生成文本
model.reset_states()
x = np.zeros((1, 1))
x[0, 0] = char2idx['R']  # 以'R'开头generated_text = []
for i in range(300):predictions = model(x)predictions = tf.squeeze(predictions, 0)predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()x[0, 0] = predicted_idgenerated_text.append(idx2char[predicted_id])print(''.join(generated_text))

✅ 输出示例:
“ROMEO: I am a villain to the world, and yet I love her...”


七、实战2:时间序列预测(股价)

import yfinance as yf
import pandas as pd
from sklearn.preprocessing import MinMaxScaler# 获取苹果股价
data = yf.download('AAPL', start='2020-01-01', end='2023-01-01')
prices = data['Close'].values.reshape(-1, 1)# 归一化
scaler = MinMaxScaler()
prices_scaled = scaler.fit_transform(prices)# 创建序列
def create_sequences(data, seq_length):X, y = [], []for i in range(len(data) - seq_length):X.append(data[i:i+seq_length])y.append(data[i+seq_length])return np.array(X), np.array(y)seq_length = 50
X, y = create_sequences(prices_scaled, seq_length)# 划分训练/测试
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]# 构建GRU模型
model = models.Sequential([layers.GRU(50, return_sequences=True, input_shape=(seq_length, 1)),layers.GRU(50),layers.Dense(1)
])model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=20, batch_size=32, validation_data=(X_test, y_test))# 预测
predictions = model.predict(X_test)
predictions = scaler.inverse_transform(predictions)
y_test_actual = scaler.inverse_transform(y_test)# 可视化
plt.figure(figsize=(12, 6))
plt.plot(y_test_actual, label='真实股价')
plt.plot(predictions, label='预测股价')
plt.legend()
plt.title('GRU股价预测')
plt.show()

✅ GRU能捕捉股价趋势,但难以预测剧烈波动。


八、RNN/LSTM/GRU对比

模型参数量训练速度长期依赖适用场景
RNN❌ 差简单序列
LSTM✅ 强复杂序列(文本、语音)
GRU✅ 较强平衡速度与性能

一般建议:优先尝试GRU,若效果不足再用LSTM。


九、总结与学习建议

本文我们:

  • 理解了RNN的循环结构与BPTT;
  • 掌握了LSTM的三大门控机制;
  • 学习了GRU的简化设计;
  • 实现了文本生成时间序列预测
  • 认识了RNN在序列建模中的强大能力。

📌 学习建议

  1. 理解门控机制:这是LSTM/GRU的核心。
  2. 处理变长序列:使用paddingmasking
  3. 监控梯度:梯度爆炸可用clipnorm
  4. 避免过拟合:使用Dropout、早停。
  5. 考虑现代替代:如Transformer在长序列上更优。

十、下一篇文章预告

第15篇:Transformer与注意力机制:自然语言处理的革命
我们将深入讲解:

  • 注意力机制(Attention)的原理与计算
  • Transformer架构(编码器-解码器)
  • 自注意力(Self-Attention)与多头注意力
  • BERTGPT等预训练模型
  • 使用Hugging Face库进行文本分类与生成

进入NLP的“新纪元”——Transformer与大语言模型的世界!


参考文献

  1. Hochreiter, S. & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation.
  2. Cho, K. et al. (2014). Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation. arXiv.
  3. TensorFlow RNN教程: https://www.tensorflow.org/guide/keras/rnn
  4. 《Deep Learning》 by Goodfellow et al. Chapter 10 (Sequence Modeling)


文章转载自:

http://eToyC7UJ.ggqcg.cn
http://X97YHPjf.ggqcg.cn
http://MOQG5azY.ggqcg.cn
http://40U5bDgP.ggqcg.cn
http://6znXswU6.ggqcg.cn
http://AAwdJwAI.ggqcg.cn
http://XybovsJH.ggqcg.cn
http://F26D1WFJ.ggqcg.cn
http://PTgYuYfS.ggqcg.cn
http://8zqY6P1B.ggqcg.cn
http://rcnrFowr.ggqcg.cn
http://Xjtk1RVA.ggqcg.cn
http://7ADjeZBT.ggqcg.cn
http://LI3dtHGa.ggqcg.cn
http://ZyPXRhM7.ggqcg.cn
http://2SRaUvp0.ggqcg.cn
http://132OZQGn.ggqcg.cn
http://parqCrxn.ggqcg.cn
http://E0AWLtYl.ggqcg.cn
http://baLttopJ.ggqcg.cn
http://IyOsNy16.ggqcg.cn
http://HvNnIDXa.ggqcg.cn
http://E3RCHjdm.ggqcg.cn
http://cS2CO5cE.ggqcg.cn
http://RvUUHupi.ggqcg.cn
http://4jcVgweI.ggqcg.cn
http://7DWY6mks.ggqcg.cn
http://inb7c6Hf.ggqcg.cn
http://fb5TX26S.ggqcg.cn
http://e3uZqHLB.ggqcg.cn
http://www.dtcms.com/a/373841.html

相关文章:

  • 【P02_AI大模型之调用LLM的方式】
  • 浅谈Go 语言开发 AI Agent
  • pgsql for循环一个 数据文本 修改数据 文本如下 ‘40210178‘, ‘40210175‘, ‘40210227‘, ‘40210204‘
  • 工业检测机器视觉为啥非用工业相机?普通相机差在哪?
  • 基于MATLAB的粒子群算法优化广义回归神经网络的实现
  • 25年9月通信基础知识补充1:NTN-TDL信道建模matlab代码(satellite-communications toolbox学习)
  • Aider AI Coding项目 流式处理架构深度分析
  • 打通各大芯片厂商相互间的壁垒,省去繁琐重复的适配流程的智慧工业开源了
  • PAT 1103 Integer Factorization
  • WindowManagerService (WMS)
  • Tool | AI类网址收录
  • SU-03T语音模块的使用
  • kubernetes-lxcfs解决资源可见性问题
  • 235kw发动机飞轮设计说明书CAD+设计说明书
  • Day9 | 类、对象与封装全解析
  • 【財運到】股票期货盯盘助手V3-盯盘界面找不到了
  • “微服务“一词总是出现,它是什么?
  • 打包应用:使用 Electron Forge
  • 详解布隆过滤器
  • ArcGIS学习-16 实战-栅格数据可达性分析
  • MySQL全库检索关键词 - idea 工具 Full-Text Search分享
  • Android小工具:使用python生成适配不同分辨率的dimen文件
  • 基于Python的电影推荐系统【2026最新】
  • 【C语言入门级教学】内存函数
  • 第三届“陇剑杯”CTF比赛部分WP(Web部分和应急)
  • 人工智能-python-深度学习-神经网络VGG(详解)
  • Spring框架重点概述
  • vue2+el的树形穿梭框
  • JuiceFS分布式文件系统
  • 【数据结构】简介