【Pytorch✨】LSTM01 入门
🪶 一、LSTM 是什么?
LSTM(Long Short-Term Memory)是一种特殊的循环神经网络(RNN),能够记住“过去的信息”并决定“哪些记住,哪些忘掉”。
它非常适合处理以下这种“时间有关”的任务:
- 给一段文字,让模型猜下一个词
- 给一段语音,让模型识别其中内容
- 给一串气温数据,让模型预测明天的温度
🪶 二、为什么叫“长短期记忆”?
这个名字说明了它的最大特点:
词 | 含义 |
---|---|
Long(长期) | 能保留很久以前的有用信息,比如前面的一个关键词 |
Short(短期) | 也能处理最近刚刚输入的信息 |
Memory(记忆) | 就像人脑一样,记住或忘记信息有“策略”和“意图” |
传统的 RNN 容易“忘掉”很久之前的信息,而 LSTM 通过“门控结构”克服了这个问题!
📥 三、LSTM 的输入/输出长什么样?
假设我们用 LSTM 来预测天气(气温),你有:
数据: [21.0, 21.3, 21.8] → 想预测下一个值 22.0
输入 X 是形状为 [batch_size, seq_len, input_size]
的张量
可以理解成:一次给 LSTM 喂多少条序列,每条序列有多少时间步,每个时间步的输入有多少个特征
例如:`[[[21.0], [21.3], [21.8]]]` → `[1, 3, 1]`
名称 | 含义 | 举例 |
---|---|---|
batch_size | 一次送入模型的“样本数量”(多少条数据) | 10 表示一次训练 10 条序列 |
seq_len | 每条序列的“时间步”长度(有几个输入) | 5 表示每条数据是 5 天的气温 |
input_size | 每个时间步包含几个“特征” | 1 表示每步只输入一个数字(如温度) |
输出 y 是下一个值,比如 [[22.0]]
🪶 四、LSTM 的内部结构(过程)
LSTM 的核心是 “三个门 + 一个细胞状态”:
┌─────────────────────────────┐
输入 →───►│ 1. 遗忘门(forget gate) │ ← 过去记忆决定要忘掉多少└─────────────────────────────┘┌─────────────────────────────┐
输入 →───►│ 2. 输入门(input gate) │ ← 新信息能不能写入记忆└─────────────────────────────┘┌─────────────────────────────┐
过去记忆 →│ 3. 输出门(output gate) │──► 输出(隐状态)给下一步└─────────────────────────────┘
小结
门 | 作用 |
---|---|
遗忘门 | 决定“旧记忆要不要保留” |
输入门 | 决定“新输入要不要加入到记忆中” |
输出门 | 决定“当前记忆要不要输出到下一个” |
🪶 五、整体流程图
时间步1 时间步2 时间步3
x₁ ─┬─► LSTM ─► h₁ ─┬─► LSTM ─► h₂ ─┬─► LSTM ─► h₃│ │ │c₁(记忆) c₂(记忆) c₃(记忆)
每个时间步都会:
- 接收一个输入 xₜ
- 接收前一个时间步的隐藏状态 hₜ₋₁ 和记忆状态 cₜ₋₁
- 输出当前的隐藏状态 hₜ 和更新后的记忆 cₜ