NLP学习路线图(二十三):长短期记忆网络(LSTM)
在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为解决这一问题的希望之光。然而,RNN 在实践中遭遇了严峻的挑战——梯度消失与梯度爆炸问题,使其难以有效学习长距离依赖。长短期记忆网络(Long Short-Term Memory, LSTM) 应运而生,成为解决这一瓶颈的革命性方案,极大地推动了序列建模的发展。
一、RNN的困境:梯度消失与长期依赖难题
RNN 的核心思想是引入循环连接,使网络具备一定的记忆能力。其结构可抽象为以下公式:
h_t = f(W_xh * x_t + W_hh * h_{t-1} + b_h)
其中,h_t
是当前时刻的隐藏状态,x_t
是当前输入,h_{t-1}
是前一时刻的隐藏状态,W_xh
和 W_hh
是权重矩阵,b_h
是偏置,f
是激活函数(如 tanh)。
-
短期记忆的有效性: RNN 理论上能利用
h_{t-1}
携带之前所有时刻的信息,对于短序列(如几个词组成的短语)依赖建模效果良好。 -
梯度消失的诅咒: 训练 RNN 通常使用时间反向传播(BPTT) 算法。在计算损失函数对较早时刻参数的梯度时,需要连续乘以多个时刻的
∂h_t / ∂h_{t-1}
(即 Jacobian 矩阵)。当该矩阵的特征值小于 1(常见于 tanh 激活函数),多次连乘会导致梯度指数级衰减至接近零。 -
长期依赖的失效: 梯度消失意味着模型参数几乎无法根据远距离的误差信号进行有效更新。因此,RNN 难以学习序列中相隔较远的元素之间的重要关联(如主谓一致、指代关系)。梯度爆炸问题(特征值大于 1 导致梯度激增)虽也存在,但可通过梯度裁剪等技术缓解,其危害性相对小于梯度消失。
二、LSTM的诞生:引入门控的记忆单元
为了解决 RNN 的根本缺陷,Sepp Hochreiter 和 Jürgen Schmidhuber 在 1997 年提出了 LSTM。其核心创新在于用一个精巧设计的记忆单元(Cell State) 替代 RNN 简单的隐藏状态,并引入三个门控机制(Gates) 来精确调控信息的流动。
-
记忆单元(Cell State -
C_t
):信息的高速公路-
这是 LSTM 的核心,可以将其想象成一条贯穿时间的“传送带”。
-
它的设计目标是在序列处理过程中,相对稳定地传输信息,尤其是那些需要长期保存的信息。
-
信息在
C_t
上的流动主要受三个门控结构的精细调节,而非直接参与非线性变换,这大大缓解了梯度消失问题。
-
-
遗忘门(Forget Gate -
f_t
):选择性遗忘-
功能: 决定记忆单元
C_{t-1}
中哪些信息应该被丢弃或减弱。 -
计算:
f_t = σ(W_f * [h_{t-1}, x_t] + b_f)
-
原理: 接收前一时刻隐藏状态
h_{t-1}
和当前输入x_t
,通过 Sigmoid 激活函数(输出 0 到 1 之间)生成一个遗忘向量f_t
。f_t
中的每个元素对应C_{t-1}
中相应位置的信息保留程度(1 表示完全保留,0 表示完全遗忘)。 -
意义: 这是 LSTM 主动管理信息的第一步,摒弃无用或过时的历史信息,为重要新信息腾出空间。
-
-
输入门(Input Gate -
i_t
)与候选值(Candidate Value -~C_t
):选择性记忆-
输入门功能: 决定当前计算出的新信息(候选值)有多少应该被写入/更新到记忆单元
C_t
中。 -
输入门计算:
i_t = σ(W_i * [h_{t-1}, x_t] + b_i)
-
候选值功能: 根据当前输入和前一状态计算的潜在新信息,可能包含对当前时刻和未来有用的内容。
-
候选值计算:
~C_t = tanh(W_C * [h_{t-1}, x_t] + b_C)
-
更新原理: 输入门
i_t
(控制写入比例)逐元素乘以候选值~C_t
(要写入的内容),得到最终要添加的信息量。
-
-
记忆单元更新
-
计算:
C_t = f_t * C_{t-1} + i_t * ~C_t
-
解释: 这是 LSTM 的核心方程。
-
f_t * C_{t-1}
:遗忘门控制下的旧记忆保留。 -
i_t * ~C_t
:输入门控制下的新信息添加。
-
-
意义: 该操作是线性的(逐元素乘法和加法),使得梯度在
C_t
上流动时能够保持相对稳定,显著缓解了梯度消失问题。信息在这里被选择性地更新,而非完全覆盖。
-
-
输出门(Output Gate -
o_t
):控制输出-
功能: 基于当前的记忆单元
C_t
,决定下一个隐藏状态h_t
应该输出什么(即哪些信息对当前时刻的输出或下一个时刻的计算是重要的)。 -
计算:
-
o_t = σ(W_o * [h_{t-1}, x_t] + b_o)
-
h_t = o_t * tanh(C_t)
-
-
原理: 输出门
o_t
控制经过 tanh 压缩(将值规范到 -1 到 1 之间)后的当前记忆单元C_t
有多少被输出为隐藏状态h_t
。h_t
用于计算当前时刻的输出y_t
(如果任务需要)并传递到下一个 LSTM 单元。
-
LSTM 单元核心公式总结:
f_t = σ(W_f · [h_{t-1}, x_t] + b_f) // 遗忘门
i_t = σ(W_i · [h_{t-1}, x_t] + b_i) // 输入门
o_t = σ(W_o · [h_{t-1}, x_t] + b_o) // 输出门
~C_t = tanh(W_C · [h_{t-1}, x_t] + b_C) // 候选记忆值
C_t = f_t * C_{t-1} + i_t * ~C_t // 更新细胞状态
h_t = o_t * tanh(C_t) // 计算隐藏状态输出
三、LSTM如何解决梯度消失与捕获长期依赖
-
记忆单元 (
C_t
) 的线性流动:-
从
C_t = f_t * C_{t-1} + i_t * ~C_t
可以看出,C_t
对C_{t-1}
的偏导∂C_t / ∂C_{t-1}
主要取决于遗忘门f_t
(和i_t
对C_{t-1}
的间接依赖,但通常较弱)。 -
f_t
是 Sigmoid 的输出(值在 0~1 之间),而不是多个小于 1 的 Jacobian 矩阵连乘。这使得∂C_t / ∂C_{t-1}
的值更有可能保持在合理的范围内(接近 1 或 0),避免了 RNN 中因连续非线性变换导致的指数级衰减。 -
梯度可以通过这条相对“平缓”的
C_t
路径更有效地传播到遥远的过去时刻。
-
-
门控机制:信息流的精细控制
- 遗忘门 (
f_t
): 允许模型显式地“忘记”无关的旧信息(如f_t
接近 0),防止过时信息干扰当前决策。这减少了需要长期维护的信息量,也间接保护了真正重要的长期信息。 - 输入门 (
i_t
): 允许模型有选择地只将相关且重要的新信息(~C_t
)整合到记忆单元中。避免信息过载,保持C_t
的“纯净性”和长期有效性。 - 输出门 (
o_t
): 控制当前记忆单元的内容有多少影响当前输出和传递给下一时刻的隐藏状态h_t
。这使得模型能根据当前任务需要,输出记忆单元中不同部分的信息。 - 门控的动态性: 所有门(
f_t
,i_t
,o_t
)的值都是根据当前输入x_t
和前一刻状态h_{t-1}
动态计算的。这意味着 LSTM 在每个时间步都能根据具体上下文,自适应地决定记住什么、忘记什么、输出什么,具有强大的情境适应能力。
四、LSTM的变体与拓展
基础 LSTM 结构取得了巨大成功,研究者们提出了多种改进和变体:
-
门控循环单元(GRU - Gated Recurrent Unit):
-
由 Cho 等人于 2014 年提出,旨在简化 LSTM 结构,提高计算效率。
-
核心简化: 将遗忘门和输入门合并为一个单一的更新门(Update Gate -
z_t
),并合并了细胞状态C_t
和隐藏状态h_t
。 -
核心公式:
z_t = σ(W_z · [h_{t-1}, x_t]) // 更新门 r_t = σ(W_r · [h_{t-1}, x_t]) // 重置门 ~h_t = tanh(W · [r_t * h_{t-1}, x_t]) // 候选激活值 h_t = (1 - z_t) * h_{t-1} + z_t * ~h_t // 更新隐藏状态
-
特点: 参数更少,计算速度通常更快;在许多任务上与 LSTM 性能相当,有时甚至更好;理解和使用相对简单。
-
-
双向LSTM(BiLSTM - Bidirectional LSTM):
-
动机: 标准 LSTM 只利用了过去(左边)的上下文信息。许多任务(如词性标注、命名实体识别)中,未来(右边)的上下文信息同样至关重要。
-
结构: 包含两个独立的 LSTM 层:
-
前向层(Forward Layer): 按时间顺序(
t=1 -> T
)处理序列,捕获过去上下文。 -
后向层(Backward Layer): 按时间逆序(
t=T -> 1
)处理序列,捕获未来上下文。
-
-
输出组合: 通常将每个时刻
t
的前向隐藏状态h_t^f
和后向隐藏状态h_t^b
拼接(Concatenate) 起来形成最终的输出表示[h_t^f; h_t^b]
。这个表示融合了整个序列在时刻t
的上下文信息。 -
应用: 在需要全局上下文理解的 NLP 任务(如序列标注、文本分类、情感分析)中表现出色。
-
-
深层LSTM(Deep LSTM):
-
动机: 单层 RNN/LSTM 的表征能力有限。通过堆叠多层 LSTM 单元可以学习更抽象、更复杂的特征表示。
-
结构: 将前一层的隐藏状态
h_t^{(l-1)}
作为下一层的输入x_t^{(l)}
。h_t^{(1)} = LSTM_1(x_t, h_{t-1}^{(1)}) h_t^{(2)} = LSTM_2(h_t^{(1)}, h_{t-1}^{(2)}) ... h_t^{(L)} = LSTM_L(h_t^{(L-1)}, h_{t-1}^{(L)})
-
特点: 高层 LSTM 处理的是低层 LSTM 产生的抽象表示,能够捕捉更复杂的模式和长距离依赖。但也增加了模型复杂度和训练难度(需要更多数据和更谨慎的参数初始化/正则化)。
-
五、LSTM在NLP中的经典应用
LSTM 及其变体在深度学习时代的 NLP 发展中扮演了核心角色:
-
文本分类(Text Classification):
-
任务: 判断整段文本的类别(如新闻主题分类、情感分析[正面/负面]、垃圾邮件检测)。
-
模型: 通常使用单向或双向 LSTM/GRU 处理输入的词序列。取最后一个时刻的隐藏状态
h_T
或对所有时刻隐藏状态进行平均/最大池化(Pooling) 得到的向量作为整个文本的表示。将此表示输入到一个全连接层进行分类。 -
优势: 能有效捕捉文本中的词序信息和上下文语义,比简单的词袋模型强大得多。
-
-
序列标注(Sequence Labeling):
-
任务: 为输入序列中的每一个单元(token)分配一个标签(如词性标注[名词、动词等]、命名实体识别[人名、地名、机构名]、中文分词)。
-
模型: 双向 LSTM (BiLSTM) 是此任务的标配。每个时刻
t
的输出[h_t^f; h_t^b]
融合了该词左右两侧的完整上下文信息,这对于准确判断当前词的标签至关重要。BiLSTM 的输出层通常连接一个 Softmax 分类器 为每个 token 独立预测标签。更高级的模型会在 BiLSTM 之上添加 条件随机场(CRF) 层(即 BiLSTM-CRF 模型),考虑标签之间的转移约束(如 I-PER 不能跟在 O 后面),进一步提升效果。
-
-
机器翻译(Machine Translation - MT):
-
Encoder-Decoder 架构(Seq2Seq): LSTM 是早期神经机器翻译(NMT)模型的支柱。
-
编码器(Encoder): 通常是一个(多层)BiLSTM,将源语言句子编码成一个固定长度的上下文向量(Context Vector)(通常取最后时刻的隐藏状态)。
-
解码器(Decoder): 通常是一个(多层)单向 LSTM。它以编码器产生的上下文向量作为初始状态,并自回归(Autoregressive) 地生成目标语言序列,即每一步以上一步生成的词作为输入(或结合注意力机制)。
-
-
注意力机制(Attention Mechanism): 标准的 Seq2Seq 模型瓶颈在于依赖单一的上下文向量。注意力机制允许解码器在生成每一个目标词时,“动态地关注”编码器输出的不同部分(即源句子中不同词或短语的表示)。这极大地提高了模型,尤其是 LSTM 模型处理长句子和捕捉对齐关系的能力,是 NMT 性能飞跃的关键。计算 Attention 分数和上下文向量是核心步骤。
-
-
文本生成(Text Generation):
-
任务: 根据给定的输入(可能为空、一个前缀或一个主题)生成连贯、流畅、符合语法和语义的文本(如对话生成、诗歌创作、故事续写)。
-
模型: 通常基于 LSTM 的自回归语言模型。
-
训练时:输入一段文本序列(
x_1, x_2, ..., x_T
),模型学习预测下一个词的概率分布P(x_{t+1} | x_1, ..., x_t)
。 -
生成时:给定起始词或提示(Prompt),模型根据学习到的分布,递归地采样出下一个词(
x_{t+1}
),并将其作为下一时刻的输入,如此循环直至生成结束标记或达到长度限制。
-
-
优势: LSTM 的记忆能力使其能够维持生成文本的连贯性和主题一致性,比基于 N-gram 的模型能生成更长的、更自然的文本。
-
六、LSTM的代码实现示例(概念性伪代码)
import torch
import torch.nn as nnclass LSTMCell(nn.Module):def __init__(self, input_size, hidden_size):super(LSTMCell, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_size# 定义权重参数: 将输入和上一时刻隐藏状态映射到遗忘门、输入门、输出门、候选值self.W_f = nn.Linear(input_size + hidden_size, hidden_size) # 遗忘门权重self.W_i = nn.Linear(input_size + hidden_size, hidden_size) # 输入门权重self.W_c = nn.Linear(input_size + hidden_size, hidden_size) # 候选值权重self.W_o = nn.Linear(input_size + hidden_size, hidden_size) # 输出门权重def forward(self, x_t, state):# state 是一个元组 (h_{t-1}, C_{t-1})h_prev, C_prev = state# 拼接当前输入x_t和前一时序隐藏状态h_prevcombined = torch.cat((x_t, h_prev), dim=1)# 计算各门控和候选值 (应用Sigmoid/tanh激活)f_t = torch.sigmoid(self.W_f(combined)) # 遗忘门i_t = torch.sigmoid(self.W_i(combined)) # 输入门o_t = torch.sigmoid(self.W_o(combined)) # 输出门C_tilde = torch.tanh(self.W_c(combined)) # 候选记忆值# 更新细胞状态: C_t = f_t * C_prev + i_t * C_tildeC_t = f_t * C_prev + i_t * C_tilde# 计算当前隐藏状态: h_t = o_t * tanh(C_t)h_t = o_t * torch.tanh(C_t)# 返回当前隐藏状态和更新后的细胞状态 (作为下一时刻的输入状态)return h_t, C_t# 使用示例 (概念性)
input_size = 10 # 输入特征维度
hidden_size = 20 # LSTM隐藏状态维度
lstm_cell = LSTMCell(input_size, hidden_size)# 初始化隐藏状态和细胞状态 (通常初始化为零)
h0 = torch.zeros(1, hidden_size) # (batch_size, hidden_size)
C0 = torch.zeros(1, hidden_size)# 处理一个时间步的输入
x_t = torch.randn(1, input_size) # 当前时刻输入 (batch_size, input_size)
h_t, C_t = lstm_cell(x_t, (h0, C0))# h_t 可作为输出或传递给下一个时间步