循环神经网络 - 长短期记忆网络的门控机制
长短期记忆网络(LSTM)的门控机制是其核心设计,用来解决普通 RNN 在长程依赖中遇到的梯度消失与信息混淆问题。为了更进一步理解长短期记忆网络,本文我们来深入分析一下其门控机制。
一、理解长短期记忆网络的“三个门”
所谓门控机制,在数字电路中,门(gate)为一个二值变量 {0, 1},0 代表关闭状态,不 许任何信息通过;1 代表开放状态,允许所有信息通过。大家可以在这个基础上,来理解长短期记忆网络的门控机制。
LSTM 网络引入门控机制(Gating Mechanism)来控制信息传递的路径。前一博文循环神经网络 - 长短期记忆网络-CSDN博客
提到的三个“门”分别为输入门𝒊_𝑡、遗忘门𝒇_𝑡 和输出门𝒐_𝑡。关于这三个门的公式和定义,大家可以返回去复习一下,这里为了辅助大家理解,再次整体列出:
这三个门分别的作用为:
(1) 遗忘门 𝒇_𝑡 控制上一个时刻的内部状态 𝒄_(𝑡−1) 需要遗忘多少信息。
(2) 输入门𝒊_t 控制当前时刻的候选状态𝒄̃ 有多少信息需要保存。
(3) 输出门 𝒐_𝑡 控制当前时刻的内部状态 𝒄𝑡 有多少信息需要输出给外部状态𝒉_𝑡。
当 𝒇_𝑡 = 0, 𝒊_t = 1 时,记忆单元将历史信息清空,并将候选状态向量 写入。但此时记忆单元𝒄_t依然和上一时刻的历史信息相关。当𝒇_𝑡 =1,𝒊_𝑡 =0时,记忆单元将复制上一时刻的内容,不写入新的信息。
需要注意的是,LSTM 网络中的“门”是一种“软”门,取值在 (0, 1) 之间,表示以一定的比例允许信息通过。
二、LSTM 网络的循环单元结构
1、计算过程:
(1)首先利用上一时刻的外部状态 𝒉_(𝑡−1) 和当前时刻的输入 𝒙_t ,计算出三个门,以及候选状态
(2)结合遗忘门 𝒇_𝑡 和输入门 𝒊_𝑡 来更新记忆单元 𝒄_𝑡
(3)结合输出门 𝒐_𝑡,将内部状态的信息传递给外部状态 𝒉_𝑡
可以如下图所表示:
通过LSTM循环单元,整个网络可以建立较长距离的时序依赖关系。整个过程可以抽象总结为:
三、LSTM 网络的记忆
循环神经网络中的隐状态 𝒉 存储了历史信息,可以看作一种记忆(Mem- ory)。在简单循环网络中,隐状态每个时刻都会被重写,因此可以看作一种短期记忆(Short-Term Memory)。在神经网络中,长期记忆(Long-Term Mem- ory)可以看作网络参数,隐含了从训练数据中学到的经验,其更新周期要远远慢于短期记忆。
而在 LSTM 网络中,记忆单元 𝒄 可以在某个时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔。记忆单元 𝒄 中保存信息的生命周期要长于短期记忆 𝒉,但又远远短于长期记忆,因此称为长短期记忆(Long Short-Term Memory)。
这里额外需要注意一点:一般在深度网络参数学习时,参数初始化的值一般都比较小。但是在训 练 LSTM 网络时,过小的值会使得遗忘门的值比较小。这意味着前一时刻的信息大部分都丢失了,这样网络很难捕捉到长距离的依赖信息。并且相邻时间间隔的梯度会非常小,这会导致梯度弥散问题。因此遗忘的参数初始值一般都设得比较大,其偏置向量 𝒃_𝑓 设为 1 或 2。
四、进一步理解长短期记忆网络的门控机制
门控机制主要通过三个“门”来有选择地控制信息的流入、保留和流出:
-
遗忘门(Forget Gate)
遗忘门决定了在当前时刻应当从记忆单元中舍弃哪些信息。它接收前一时刻的隐藏状态 h_{t-1} 与当前输入 x_t,经过一个线性变换后,再通过 sigmoid 激活函数输出一个介于 0 和 1 之间的向量 f_t: -
输入门(Input Gate)
输入门用于控制当前输入带来多少新信息更新进记忆单元。它包括两部分:-
一个 sigmoid 层决定哪些信息需要更新:
-
一个 tanh 层生成新候选记忆内容:
输入门的输出 i_t 表示每个维度上新候选信息
应该更新多少,从而有选择地把当前输入的信息融合进长期记忆中。
-
-
记忆单元更新
记忆单元(也称为细胞状态),表示网络中长期保存的信息。更新时结合了遗忘门和输入门的作用:这里“⊙”表示逐元素相乘。整个表达式表明:上一时刻的记忆经过遗忘门过滤后,与当前候选记忆(按输入门决定比例)相加,构成当前时刻的记忆。
-
输出门(Output Gate)
输出门决定了当前时刻哪些记忆信息最终作为隐藏状态输出给下一时间步或当前任务。具体计算为:然后结合记忆单元经过 tanh\tanhtanh 激活后的信息生成隐藏状态:
这一步就是将过滤后的记忆转化为对外传递的信号,其中 o_t 控制着哪一部分的记忆信息最终展现出来。
如何理解这些机制对信息处理的意义
想象你阅读一篇文章,在开头你记住了一些关键的背景信息(比如主角是谁、情节背景如何)。随着故事推进,有的信息可能不再重要(遗忘门帮助你丢弃不再相关的信息),而新的信息(例如新的事件或人物)会加入你的记忆(通过输入门选择性更新)。最后,当你回顾时,你会聚焦于最重要的细节(通过输出门决定哪些记忆有用并向外部输出)。这种动态的信息管理正是 LSTM 的门控机制在工作,它让网络可以:
-
保留长期重要信息:例如在自然语言中,有时句子开头的信息对理解整个句子至关重要。遗忘门和输入门一起确保这种信息不会轻易遗忘。
-
更新即时信息:当前输入与上下文经过非线性变换后融入记忆,使得模型能够灵活响应新的输入。
-
过滤输出:输出门挑选出最相关的信息供下一步处理,保持整个输出的一致性。
为什么简单的激活函数足以实现复杂的门控?
虽然 LSTM 内部使用的激活函数(如 sigmoid 和 tanh)数学表达相对简单,但它们的核心在于提供非线性变换能力。每一个门都使用 sigmoid 将信息压缩到 [0, 1] 的区间,用于调节信息流;而 tanh 则用于生成候选记忆,使得信息具有充分的正负区分。通过层层叠加这些简单操作,网络能够组合出非常复杂的控制策略,实现对不同信息的精确调控,从而有效地捕捉并利用长程依赖。
例如,在语言生成任务中,门控机制能让模型记住句子前面的背景信息,并在生成后续内容时根据信息的重要性加以调用。通过不断的训练,模型自适应地学习各个权重参数,使得即使最简单的激活函数组合起来,也能充分模拟现实世界中复杂的语义关系和信息动态变化。
五、长短期记忆网络的参数学习
下面给出 LSTM 网络参数学习的具体关键步骤,从前向传播到反向传播再到参数更新的整个流程。LSTM 参数主要包括输入到记忆单元的权重 W_{xh}、隐含状态之间的权重 W_{hh}(包括门控机制中的各个门)、以及输出层的参数 W_{hy} 和相应的偏置项。
1. 前向传播(Forward Pass)
-
初始化
-
将初始隐藏状态 h_0 和记忆单元状态 c_0 通常初始化为零向量。
-
-
依次计算每个时间步的状态
对于每个时间步 t,根据当前输入 x_t 和上一时刻的隐藏状态 h_{t-1},执行以下计算: -
保存中间状态
记录每个时间步的 h_t 和 c_t 以便后续反向传播使用。
2. 损失计算
3. 反向传播通过时间(BPTT)
-
局部梯度计算
-
误差传递与梯度累积
-
更新各参数的梯度
4. 参数更新
-
应用优化算法
在所有时间步梯度累加完成后,根据优化器(如 SGD、Adam 等)的规则,更新所有共享参数:
举例说明
总结
LSTM 网络参数学习的 BPTT 过程可归纳为以下关键步骤:
-
前向传播:按时间顺序计算每个时间步的隐藏状态 hth_t 和记忆单元 ctc_t 并保存中间状态。
-
损失计算:在每个时间步计算输出与目标之间的损失,汇总得到总损失。
-
局部梯度计算:对当前输出误差进行反向传播,计算每个时间步输出层和隐藏层(包括门控输入)的局部梯度。
-
梯度传递与累积:从最后时间步开始,通过链式法则将误差从后向前传递,并累积对各层共享参数的梯度。
-
参数更新:利用累积的梯度,根据优化算法更新各个参数。
通过这一过程,LSTM 不仅能够捕捉短期信息,还能通过记忆单元维持长程依赖,从而使网络在各种序列任务中表现出良好的性能。