LSTM的一个计算例子
要理解LSTM(长短期记忆网络)的门机制,我们可以通过一个具体的"句子代词指代"任务来演示。这个任务需要模型记住句子中早期出现的名词(如"Alice"),并在后续出现代词(如"She")时正确关联,非常适合展示LSTM如何通过门机制处理长期依赖。
任务场景
我们要处理的句子是:"Alice is happy. She smiles."
目标:让模型在看到"She"时,能记住前面的"Alice"是女性,从而正确理解"She"指代"Alice"。
LSTM核心结构回顾
LSTM通过三个门控制信息的流动,核心是细胞状态(cell state)(类似"长期记忆")和隐藏状态(hidden state)(类似"短期输出"):
- 遗忘门(Forget Gate):决定从细胞状态中丢弃哪些旧信息(0=完全遗忘,1=完全保留)
- 输入门(Input Gate):决定哪些新信息被存入细胞状态
- 输出门(Output Gate):决定从细胞状态中输出哪些信息作为当前隐藏状态
我将完善LSTM计算示例,补充具体的参数初始值和向量定义,让整个计算过程更清晰。
1. 向量定义
- 输入向量维度:3(分别表示"女性特征"、“动词特征”、“代词特征”)
- 隐藏状态维度:1(简化计算,实际中可更大)
[hₜ₋₁, xₜ]
表示将隐藏状态与输入向量拼接,维度为1+3=4
2. 初始参数(实际训练中学习得到,此处为演示设定)
- 遗忘门权重 Wf = [0.2, 0.3, -0.1, 0.4],偏置 bf = -0.1
- 输入门权重 Wi = [0.5, 0.2, 0.3, -0.2],偏置 bi = 0.3
- 输出门权重 Wo = [0.1, -0.2, 0.4, 0.3],偏置 bo = -0.2
- 候选细胞状态权重 Wc = [0.3, 0.1, -0.2, 0.5],偏置 bc = 0.1
- 初始隐藏状态 h₋₁ = [0]
- 初始细胞状态 C₋₁ = [0]
3. 输入编码
- “Alice” → x₀ = [1, 0, 0](女性名词)
- “is” → x₁ = [0, 1, 0](系动词)
- “happy” → x₂ = [0, 0, 0](形容词,无特殊标记)
- “She” → x₃ = [1, 0, 1](女性代词)
- “smiles” → x₄ = [0, 1, 0](动词)
完整计算过程
时间步t=0:输入"Alice"
-
拼接向量:
[h₋₁, x₀] = [0, 1, 0, 0]
-
遗忘门计算:
f₀ = sigmoid(Wf·[h₋₁, x₀] + bf) = sigmoid(0.2×0 + 0.3×1 + (-0.1)×0 + 0.4×0 + (-0.1)) = sigmoid(0.3 - 0.1) = sigmoid(0.2) ≈ 0.55
(对初始空记忆轻度保留)
-
输入门计算:
i₀ = sigmoid(Wi·[h₋₁, x₀] + bi) = sigmoid(0.5×0 + 0.2×1 + 0.3×0 + (-0.2)×0 + 0.3) = sigmoid(0.2 + 0.3) = sigmoid(0.5) ≈ 0.62Ĉ₀ = tanh(Wc·[h₋₁, x₀] + bc) = tanh(0.3×0 + 0.1×1 + (-0.2)×0 + 0.5×0 + 0.1) = tanh(0.1 + 0.1) = tanh(0.2) ≈ 0.197
-
细胞状态更新:
C₀ = f₀×C₋₁ + i₀×Ĉ₀ = 0.55×0 + 0.62×0.197 ≈ 0.122
(细胞状态开始记录"Alice是女性"的信息)
-
输出门计算:
o₀ = sigmoid(Wo·[h₋₁, x₀] + bo) = sigmoid(0.1×0 + (-0.2)×1 + 0.4×0 + 0.3×0 + (-0.2)) = sigmoid(-0.2 - 0.2) = sigmoid(-0.4) ≈ 0.401h₀ = o₀×tanh(C₀) = 0.401×tanh(0.122) ≈ 0.401×0.121 ≈ 0.048
时间步t=1:输入"is"
-
拼接向量:
[h₀, x₁] = [0.048, 0, 1, 0]
-
遗忘门计算:
f₁ = sigmoid(Wf·[h₀, x₁] + bf) = sigmoid(0.2×0.048 + 0.3×0 + (-0.1)×1 + 0.4×0 + (-0.1)) = sigmoid(0.0096 - 0.1 - 0.1) = sigmoid(-0.1904) ≈ 0.452
(对已有记忆轻度遗忘)
-
输入门计算:
i₁ = sigmoid(Wi·[h₀, x₁] + bi) = sigmoid(0.5×0.048 + 0.2×0 + 0.3×1 + (-0.2)×0 + 0.3) = sigmoid(0.024 + 0.3 + 0.3) = sigmoid(0.624) ≈ 0.651Ĉ₁ = tanh(Wc·[h₀, x₁] + bc) = tanh(0.3×0.048 + 0.1×0 + (-0.2)×1 + 0.5×0 + 0.1) = tanh(0.0144 - 0.2 + 0.1) = tanh(-0.0856) ≈ -0.085
-
细胞状态更新:
C₁ = f₁×C₀ + i₁×Ĉ₁ = 0.452×0.122 + 0.651×(-0.085) ≈ 0.055 - 0.055 ≈ 0.000
("is"作为虚词,几乎不改变细胞状态)
-
输出门计算:
o₁ = sigmoid(Wo·[h₀, x₁] + bo) = sigmoid(0.1×0.048 + (-0.2)×0 + 0.4×1 + 0.3×0 + (-0.2)) = sigmoid(0.0048 + 0.4 - 0.2) = sigmoid(0.2048) ≈ 0.551h₁ = o₁×tanh(C₁) ≈ 0.551×0 ≈ 0.000
时间步t=3:输入"She"(省略t=2"happy"步骤,C₂≈0.012,h₂≈0.005)
-
拼接向量:
[h₂, x₃] = [0.005, 1, 0, 1]
-
遗忘门计算:
f₃ = sigmoid(Wf·[h₂, x₃] + bf) = sigmoid(0.2×0.005 + 0.3×1 + (-0.1)×0 + 0.4×1 + (-0.1)) = sigmoid(0.001 + 0.3 + 0.4 - 0.1) = sigmoid(0.601) ≈ 0.646
(保留之前的记忆)
-
输入门计算:
i₃ = sigmoid(Wi·[h₂, x₃] + bi) = sigmoid(0.5×0.005 + 0.2×1 + 0.3×0 + (-0.2)×1 + 0.3) = sigmoid(0.0025 + 0.2 - 0.2 + 0.3) = sigmoid(0.3025) ≈ 0.575Ĉ₃ = tanh(Wc·[h₂, x₃] + bc) = tanh(0.3×0.005 + 0.1×1 + (-0.2)×0 + 0.5×1 + 0.1) = tanh(0.0015 + 0.1 + 0.5 + 0.1) = tanh(0.7015) ≈ 0.605
-
细胞状态更新:
C₃ = f₃×C₂ + i₃×Ĉ₃ = 0.646×0.012 + 0.575×0.605 ≈ 0.008 + 0.348 ≈ 0.356
(细胞状态显著更新,强化了女性特征记忆)
-
输出门计算:
o₃ = sigmoid(Wo·[h₂, x₃] + bo) = sigmoid(0.1×0.005 + (-0.2)×1 + 0.4×0 + 0.3×1 + (-0.2)) = sigmoid(0.0005 - 0.2 + 0.3 - 0.2) = sigmoid(-0.0995) ≈ 0.475h₃ = o₃×tanh(C₃) = 0.475×tanh(0.356) ≈ 0.475×0.345 ≈ 0.164
(输出了较强的信号,表明模型成功将"She"与"Alice"关联)
门机制的关键作用总结
-
遗忘门:选择性遗忘无关信息,保留重要的上下文(如保留"Alice"的女性特征)
-
输入门:只将有价值的新信息存入细胞状态(如"Alice"作为主体被重点存储,而"is"这样的虚词几乎不被记住)
-
输出门:根据当前输入动态调整输出强度(当遇到"She"时,输出与"Alice"相关的记忆信息)
通过这三种门的协同作用,LSTM能够有效解决传统RNN的长期依赖问题,在处理长序列时保持关键信息的记忆能力。