RNN与LSTM详解:AI是如何“记住”信息的?
RNN(循环神经网络)的基本原理
RNN是一种处理序列数据的神经网络,其核心在于通过隐藏状态(hidden state)传递历史信息。每个时间步的输入不仅包括当前数据,还包含上一时间步的隐藏状态,形成循环连接。数学表达为:
$$ h_t = \sigma(W_{xh}x_t + W_{hh}h_{t-1} + b_h) $$
其中,$h_t$是当前隐藏状态,$x_t$是输入,$W$为权重矩阵,$b$为偏置,$\sigma$为激活函数(如tanh)。
RNN的局限性
传统RNN存在梯度消失或爆炸问题,难以捕获长距离依赖关系。例如,在文本生成任务中,早期的单词信息可能无法有效传递到后续时间步。
LSTM(长短期记忆网络)的改进
LSTM通过引入门控机制(输入门、遗忘门、输出门)和细胞状态(cell state)解决RNN的缺陷。其核心结构如下:
遗忘门:决定哪些信息从细胞状态中丢弃
$$ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) $$
输入门:更新细胞状态
$$ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \ \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) $$
细胞状态更新
$$ C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t $$
输出门:控制当前隐藏状态输出
$$ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \ h_t = o_t \odot \tanh(C_t) $$
实际应用差异
RNN适用场景:短序列任务(如字符级文本生成),计算资源有限时。
LSTM适用场景:长序列任务(如机器翻译、语音识别),需捕获长期依赖关系。
代码示例(PyTorch实现LSTM单元)
import torch.nn as nn
lstm = nn.LSTM(input_size=100, hidden_size=128, num_layers=2)
input_seq = torch.randn(10, 3, 100) # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(input_seq)
关键结论
- RNN通过循环连接传递信息,但受限于梯度问题。
- LSTM的门控机制和细胞状态设计显式控制信息流,更适合长期记忆。
- 现代变体(如GRU)在LSTM基础上进一步简化结构,平衡性能与效率。
