第14篇:循环神经网络(RNN)与LSTM:序列建模的利器
摘要:
本文系统讲解循环神经网络(RNN)的核心思想、结构与局限,深入解析LSTM(长短期记忆)与GRU的门控机制。结合TensorFlow/Keras,实现文本生成(莎士比亚风格)与时间序列预测(股价)两大实战项目。帮助学习者掌握处理序列数据的“时间机器”。
一、为什么需要RNN?
传统神经网络(如MLP、CNN)假设输入输出独立同分布,但许多任务涉及序列数据:
- ✅ 文本:单词有前后顺序
- ✅ 语音:声音信号是时间序列
- ✅ 时间序列:股价、天气随时间变化
- ✅ 机器翻译:输入输出均为序列
✅ RNN通过“循环”结构,使网络具有“记忆”能力,能处理变长序列。
二、RNN基础
2.1 核心思想
RNN在时间步(timestep)上共享参数,并将前一时刻的隐藏状态传递给下一时刻。
2.2 结构与公式
- 隐藏状态(Hidden State):
hₜ = tanh(Wₕₕ hₜ₋₁ + Wₓₕ xₜ + bₕ)
- 输出:
yₜ = Wₕᵧ hₜ + bᵧ
其中:
xₜ
:t时刻输入hₜ
:t时刻隐藏状态(“记忆”)yₜ
:t时刻输出W
:权重矩阵
2.3 RNN的“展开”视图
x₁ → h₁ → y₁↑
x₂ → h₂ → y₂↑
x₃ → h₃ → y₃
...
✅ 所有时间步共享同一组权重(
Wₕₕ
,Wₓₕ
,Wₕᵧ
)。
三、RNN的局限:梯度消失与梯度爆炸
3.1 问题描述
在反向传播时,梯度通过时间(Backpropagation Through Time, BPTT)传递:
∂L/∂W ∝ ∂hₜ/∂hₜ₋₁ ∂hₜ₋₁/∂hₜ₋₂ ... ∂h₂/∂h₁
- 若
||∂hₖ/∂hₖ₋₁|| < 1
→ 梯度指数衰减(梯度消失) - 若
||∂hₖ/∂hₖ₋₁|| > 1
→ 梯度指数爆炸(梯度爆炸)
🚫 梯度消失导致RNN难以学习长期依赖(如句子开头与结尾的关系)。
四、LSTM:长短期记忆网络
4.1 核心思想
LSTM通过门控机制(Gating)控制信息的遗忘、更新与输出,有效缓解梯度消失。
4.2 结构与公式
LSTM单元包含:
- 细胞状态(Cell State,
cₜ
):长期记忆通道 - 隐藏状态(Hidden State,
hₜ
):短期记忆/输出
三大门:
- 遗忘门(Forget Gate):
fₜ = σ(W_f · [hₜ₋₁, xₜ] + b_f)
- 输入门(Input Gate):
iₜ = σ(W_i · [hₜ₋₁, xₜ] + b_i) c̃ₜ = tanh(W_c · [hₜ₋₁, xₜ] + b_c)
- 输出门(Output Gate):
oₜ = σ(W_o · [hₜ₋₁, xₜ] + b_o)
状态更新:
cₜ = fₜ * cₜ₋₁ + iₜ * c̃ₜ
hₜ = oₜ * tanh(cₜ)
✅
cₜ
的梯度可近乎恒定传递,解决长期依赖问题。
五、GRU:门控循环单元
5.1 简化版LSTM
GRU将遗忘门和输入门合并为更新门(Update Gate),并减少一个门。
5.2 公式
- 更新门:
zₜ = σ(W_z · [hₜ₋₁, xₜ])
- 重置门:
rₜ = σ(W_r · [hₜ₋₁, xₜ])
- 候选隐藏状态:
h̃ₜ = tanh(W · [rₜ * hₜ₋₁, xₜ])
- 最终隐藏状态:
hₜ = (1 - zₜ) * hₜ₋₁ + zₜ * h̃ₜ
✅ GRU参数更少,训练更快,性能常接近LSTM。
六、实战1:文本生成(莎士比亚风格)
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import requests# 下载莎士比亚文本
url = "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt"
text = requests.get(url).text# 创建字符映射
vocab = sorted(set(text))
char2idx = {c: i for i, c in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])# 创建序列数据集
seq_length = 100
examples_per_epoch = len(text) // (seq_length + 1)char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)def split_input_target(chunk):input_text = chunk[:-1]target_text = chunk[1:]return input_text, target_textdataset = sequences.map(split_input_target)# 批处理与打乱
BATCH_SIZE = 64
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)# 构建LSTM模型
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024model = models.Sequential([layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, seq_length]),layers.LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),layers.Dense(vocab_size)
])# 损失函数
def loss(labels, logits):return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)model.compile(optimizer='adam', loss=loss)# 训练
EPOCHS = 3
history = model.fit(dataset, epochs=EPOCHS)# 生成文本
model.reset_states()
x = np.zeros((1, 1))
x[0, 0] = char2idx['R'] # 以'R'开头generated_text = []
for i in range(300):predictions = model(x)predictions = tf.squeeze(predictions, 0)predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()x[0, 0] = predicted_idgenerated_text.append(idx2char[predicted_id])print(''.join(generated_text))
✅ 输出示例:
“ROMEO: I am a villain to the world, and yet I love her...”
七、实战2:时间序列预测(股价)
import yfinance as yf
import pandas as pd
from sklearn.preprocessing import MinMaxScaler# 获取苹果股价
data = yf.download('AAPL', start='2020-01-01', end='2023-01-01')
prices = data['Close'].values.reshape(-1, 1)# 归一化
scaler = MinMaxScaler()
prices_scaled = scaler.fit_transform(prices)# 创建序列
def create_sequences(data, seq_length):X, y = [], []for i in range(len(data) - seq_length):X.append(data[i:i+seq_length])y.append(data[i+seq_length])return np.array(X), np.array(y)seq_length = 50
X, y = create_sequences(prices_scaled, seq_length)# 划分训练/测试
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]# 构建GRU模型
model = models.Sequential([layers.GRU(50, return_sequences=True, input_shape=(seq_length, 1)),layers.GRU(50),layers.Dense(1)
])model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=20, batch_size=32, validation_data=(X_test, y_test))# 预测
predictions = model.predict(X_test)
predictions = scaler.inverse_transform(predictions)
y_test_actual = scaler.inverse_transform(y_test)# 可视化
plt.figure(figsize=(12, 6))
plt.plot(y_test_actual, label='真实股价')
plt.plot(predictions, label='预测股价')
plt.legend()
plt.title('GRU股价预测')
plt.show()
✅ GRU能捕捉股价趋势,但难以预测剧烈波动。
八、RNN/LSTM/GRU对比
模型 | 参数量 | 训练速度 | 长期依赖 | 适用场景 |
---|---|---|---|---|
RNN | 少 | 快 | ❌ 差 | 简单序列 |
LSTM | 多 | 慢 | ✅ 强 | 复杂序列(文本、语音) |
GRU | 中 | 中 | ✅ 较强 | 平衡速度与性能 |
✅ 一般建议:优先尝试GRU,若效果不足再用LSTM。
九、总结与学习建议
本文我们:
- 理解了RNN的循环结构与BPTT;
- 掌握了LSTM的三大门控机制;
- 学习了GRU的简化设计;
- 实现了文本生成与时间序列预测;
- 认识了RNN在序列建模中的强大能力。
📌 学习建议:
- 理解门控机制:这是LSTM/GRU的核心。
- 处理变长序列:使用
padding
与masking
。- 监控梯度:梯度爆炸可用
clipnorm
。- 避免过拟合:使用Dropout、早停。
- 考虑现代替代:如Transformer在长序列上更优。
十、下一篇文章预告
第15篇:Transformer与注意力机制:自然语言处理的革命
我们将深入讲解:
- 注意力机制(Attention)的原理与计算
- Transformer架构(编码器-解码器)
- 自注意力(Self-Attention)与多头注意力
- BERT、GPT等预训练模型
- 使用Hugging Face库进行文本分类与生成
进入NLP的“新纪元”——Transformer与大语言模型的世界!
参考文献
- Hochreiter, S. & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation.
- Cho, K. et al. (2014). Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation. arXiv.
- TensorFlow RNN教程: https://www.tensorflow.org/guide/keras/rnn
- 《Deep Learning》 by Goodfellow et al. Chapter 10 (Sequence Modeling)