第2天:认识LSTM
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: pytorch
(二)具体步骤
1. 什么是LSTM
LSTM(Long Short-Term Memory,长短期记忆网络)是一种特殊的循环神经网络(RNN),专门设计来解决传统RNN在处理长序列时遇到的梯度消失问题。
📖 LSTM的发展背景
传统RNN在处理长序列时面临两个主要问题:
- 梯度消失:随着序列长度增加,早期信息的梯度会急剧衰减
- 梯度爆炸:梯度可能变得过大,导致训练不稳定
LSTM通过引入"门控机制"和"细胞状态"来解决这些问题。
🔧 LSTM的核心组件
LSTM单元包含三个门和一个细胞状态:
1. 遗忘门(Forget Gate)
f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
- 作用:决定从细胞状态中丢弃什么信息
- 输出:0到1之间的值,0表示完全遗忘,1表示完全保留
2. 输入门(Input Gate)
i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C)
- 作用:决定什么新信息被存储在细胞状态中
- 两部分:决定更新什么值 + 创建候选值
3. 输出门(Output Gate)
o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o_t * tanh(C_t)
- 作用:决定输出什么部分的细胞状态
4. 细胞状态(Cell State)
C_t = f_t * C_{t-1} + i_t * C̃_t
- 作用:LSTM的"记忆",信息可以在其中流动
🎯 LSTM的工作流程
让我用一个形象的比喻来解释:
想象LSTM是一个智能的信息管理系统:
- 遗忘门像一个"删除键",决定删除哪些过时信息
- 输入门像一个"筛选器",决定接收哪些新信息
- 细胞状态像一个"主内存",存储重要信息
- 输出门像一个"发布器",决定输出什么信息
📊 LSTM vs 传统RNN对比
特征 | 传统RNN | LSTM |
---|---|---|
记忆能力 | 短期记忆 | 长短期记忆 |
梯度问题 | 梯度消失严重 | 有效缓解 |
参数数量 | 较少 | 较多(约4倍) |
训练复杂度 | 简单 | 复杂 |
长序列处理 | 困难 | 擅长 |
2. 网络结构
import torch
import torch.nn as nn class SimpleLSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): """ 类初始化 :param input_size: 每个时间步的输入特征维度 :param hidden_size: LSTM隐藏状态的维度,也决定了LSTM内部门控单元的大小 :param num_layers: LSTM的层数 :param output_size: 最终输出的维度 """ super(SimpleLSTM, self).__init__() # 定义LSTM层 # 其中batch_first=True:指定输入张量的格式为(batch_size, seq_len, input_size) # 如果不设置,默认格式是(seq_len, batch_size, input_size) self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) # 定义一个线性层,将LSTM输出映射到期望的输出维度 self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): # LSTMn层的前向传播,默认返回output和(hidden, cell_state) # lstm_out:shape(batch_size, seq_len, hidden_size) # hn:最终的隐藏状态,形状为(num_layers, batch_size, hidden_size) # cn:最终的记忆状态,形状为(num_layers, batch_size, hidden_size)与hn相同 lstm_out, (hn, cn) = self.lstm(x) # 取最后一个时间步输出 lstm_out = lstm_out[:, -1, :] # 通过全连接层将LSTM输出映射到输出维度 output = self.fc(lstm_out) return output # 参数设置
input_size = 10 # 输入特征的维度
hidden_size = 20 # LSTM隐藏层的维度
num_layers = 2 # LSTM的层数
output_size = 1 # 输出的维度 # 创建模型实例
model = SimpleLSTM(input_size, hidden_size, num_layers, output_size) # 打印模型结构
print(model) # 示例输入(batch_size, seq_len, input_size)
x = torch.randn(5, 15, input_size) # 本例相当于(5, 15, 10) # 前向传播
output = model(x)
# 计算过程如下:
# 1. 输入:(5, 15, 10)
# 2. LSTM处理:(5, 15, 10) -> (5, 15, 20)
# 3. 取最后的时间步: (5, 15, 20) -> (5, 20)
# 4. 全连接层:(5, 20) -> (5, 1) # 输出结果
print("输入shape为:", x.shape)
print("输出shape为:", output.shape)
(三)总结
LSTM的典型应用
1. 自然语言处理
- 机器翻译
- 情感分析
- 文本生成
2. 时间序列预测
- 股票价格预测
- 天气预报
- 销售预测
3. 语音识别
- 语音到文本转换
- 语音合成
4. 其他序列任务
- 视频分析
- 生物序列分析
- 异常检测