当前位置: 首页 > news >正文

深度学习基本模块:LSTM 长短期记忆网络

长短期记忆网络(LSTM)是一种特殊的循环神经网络(RNN),专门设计用于解决标准RNN在处理长序列数据时遇到的梯度消失和长期依赖问题:深度学习基本模块:RNN 循环神经网络

LSTM通过引入门控机制和细胞状态,使网络能够选择性地记住或忘记信息,从而更有效地处理长序列数据

一、LSTM介绍

1.1 结构

  • 输入层:序列数据,形状为(batch_size, seq_len, input_size)的张量(与RNN相同)

  • LSTM层

    • 核心组件

      • 细胞状态CtC_tCt,形状:(batch_size, hidden_size),贯穿整个时间步的"记忆通道",负责长期信息的传递。
      • 隐藏状态hth_tht,形状:(batch_size, hidden_size),作为当前时间步的输出,并传递到下一个时间步。
      • 门控机制:控制信息的流动,包括:
        • 遗忘门:决定从细胞状态中丢弃哪些信息
        • 输入门:决定哪些新信息存入细胞状态
        • 输出门:决定输出哪些信息
    • 可学习参数

      • 权重矩阵WfW_fWf, WiW_iWi, WcW_cWc, WoW_oWo(每个门和候选细胞状态各有一个权重矩阵),形状均为(hidden_size, input_size + hidden_size)
      • 偏置项bfb_fbf, bib_ibi, bcb_cbc, bob_obo,形状均为(hidden_size,)
    • 可学习参数(PyTorch实现)在PyTorch中,为了计算效率,这些参数被组合成更大的矩阵,但数学计算是等价的

      • 输入到隐藏的权重​:weightihl0weight_ih_l0weightihl0,形状为(4 * hidden_size, input_size)
      • 隐藏到隐藏的权重​:weighthhl0weight_hh_l0weighthhl0,形状为(4 * hidden_size, hidden_size)
      • 输入到隐藏的偏置​:biasihl0bias_ih_l0biasihl0,形状为(4 * hidden_size)
      • 隐藏到隐藏的偏置​:biashhl0bias_hh_l0biashhl0,形状为(4 * hidden_size)
  • 激活函数

    • 门控单元(遗忘门、输入门、输出门):使用Sigmoid激活函数σ\sigmaσ;输出范围:[0, 1];模拟"开关"机制,控制信息流通的比例(0表示完全关闭,1表示完全打开)
    • 候选细胞状态:使用Tanh激活函数;输出范围:[-1, 1];生成新的候选值,添加到细胞状态中
    • 输出隐藏状态:使用Tanh激活函数;输出范围:[-1, 1];作用:将细胞状态的值压缩到合理范围,然后通过输出门控制输出比例
  • 门控机制的意义

    • 遗忘门:像是一个"过滤器",决定保留多少旧信息
      • 值接近1:保留大部分信息
      • 值接近0:丢弃大部分信息
    • 输入门:像是一个"写入开关",决定添加多少新信息
      • 值接近1:添加大量新信息
      • 值接近0:添加少量新信息
    • 输出门:像是一个"读取开关",决定输出多少信息
      • 值接近1:输出大量信息
      • 值接近0:输出少量信息

这种设计使LSTM能够有效解决长期依赖问题,在长序列任务中表现出色。细胞状态CtC_tCt作为"记忆通道",可以在多个时间步中保持信息不变,而隐藏状态hth_tht则作为每个时间步的"输出表达"。

1.2 参数

  • input_size:每个时间步输入的特征数量。对于音频频谱,通常是频率维度(如梅尔频带数);对于文本处理,通常是词向量的维度。这个参数决定了输入层的大小。
  • hidden_size:隐藏状态hth_tht和细胞状态CtC_tCt的维度,决定了LSTM的记忆容量和表征能力。较大的hidden_size可以存储更多信息但会增加计算量和过拟合风险,需要根据任务复杂度平衡选择。
  • num_layers:堆叠的LSTM层数。增加层数可以提高模型的抽象能力和复杂度(深层网络可以提取更高级的特征),但也会增加计算时间和过拟合风险。通常1-3层即可满足大多数序列建模任务。
  • bias:是否使用偏置项,默认为True
  • batch_first:输入输出维度顺序,默认为False。True: (batch, seq, feature)False: (seq, batch, feature)
  • dropout:在多层LSTM中(num_layers > 1)应用dropout防止过拟合的概率。默认值:0(表示不使用 dropout)。
  • bidirectional:是否使用双向LSTM。当设置为True时,网络会同时处理前向和后向序列信息,捕获更丰富的上下文特征,但计算量会翻倍且需要更多内存。输出特征维度变为hidden_size * 2。

1.3 输入输出维度

  • 输入数据维度(batch_size, seq_len, input_size)(当batch_first=True时)
  • 输出序列维度(batch_size, seq_len, hidden_size * num_directions)
  • 最终隐藏状态(num_layers * num_directions, batch_size, hidden_size)
  • 最终细胞状态(num_layers * num_directions, batch_size, hidden_size)

以下是 LSTM 计算过程的详细说明,包括每个步骤的公式和符号的注解,以帮助理解 LSTM 的工作原理。

以下是 LSTM 计算过程的详细说明,使用 Markdown 格式,并按照您的要求将公式和符号进行格式化。

1.4 计算过程

  1. 计算遗忘门
    ft=σ(Wf⋅[ht−1,xt]+bf)f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)ft=σ(Wf[ht1,xt]+bf)

    • ftf_tft:遗忘门的输出,决定从细胞状态中丢弃哪些信息,范围在 (0,1)(0, 1)(0,1) 之间。
    • σ\sigmaσ:Sigmoid 激活函数,输出值在 0 到 1 之间,控制信息的通过比例。
    • WfW_fWf:遗忘门的权重矩阵,形状为 (hidden_size,hidden_size+input_size)(hidden\_size, hidden\_size + input\_size)(hidden_size,hidden_size+input_size)
    • [ht−1,xt][h_{t-1}, x_t][ht1,xt]:上一时间步的隐藏状态与当前输入的拼接,形成一个新的输入向量。
    • bfb_fbf:遗忘门的偏置项,形状为 (hidden_size,)(hidden\_size,)(hidden_size,)
  2. 计算输入门
    it=σ(Wi⋅[ht−1,xt]+bi)i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)it=σ(Wi[ht1,xt]+bi)

    • iti_tit:输入门的输出,决定哪些新信息将被存入细胞状态。
    • WiW_iWi:输入门的权重矩阵,形状为 (hidden_size,hidden_size+input_size)(hidden\_size, hidden\_size + input\_size)(hidden_size,hidden_size+input_size)
    • bib_ibi:输入门的偏置项,形状为 (hidden_size,)(hidden\_size,)(hidden_size,)
  3. 计算候选细胞状态
    C~t=tanh⁡(Wc⋅[ht−1,xt]+bc)\tilde{C}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)C~t=tanh(Wc[ht1,xt]+bc)

    • C~t\tilde{C}_tC~t:候选细胞状态,生成可能添加到细胞状态的新信息。
    • tanh⁡\tanhtanh:双曲正切激活函数,输出范围在 (−1,1)(-1, 1)(1,1) 之间。
    • WcW_cWc:候选细胞状态的权重矩阵,形状为 (hidden_size,hidden_size+input_size)(hidden\_size, hidden\_size + input\_size)(hidden_size,hidden_size+input_size)
    • bcb_cbc:候选细胞状态的偏置项,形状为 (hidden_size,)(hidden\_size,)(hidden_size,)
  4. 更新细胞状态
    Ct=ft⊙Ct−1+it⊙C~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_tCt=ftCt1+itC~t

    • CtC_tCt:当前时间步的细胞状态,结合遗忘门和输入门的信息。
    • Ct−1C_{t-1}Ct1:上一时间步的细胞状态。
    • ⊙\odot:逐元素乘法,表示两个向量的逐元素相乘。
  5. 计算输出门
    ot=σ(Wo⋅[ht−1,xt]+bo)o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)ot=σ(Wo[ht1,xt]+bo)

    • oto_tot:输出门的输出,决定当前隐藏状态的输出。
    • WoW_oWo:输出门的权重矩阵,形状为 (hidden_size,hidden_size+input_size)(hidden\_size, hidden\_size + input\_size)(hidden_size,hidden_size+input_size)
    • bob_obo:输出门的偏置项,形状为 (hidden_size,)(hidden\_size,)(hidden_size,)
  6. 更新隐藏状态
    ht=ot⊙tanh⁡(Ct)h_t = o_t \odot \tanh(C_t)ht=ottanh(Ct)

    • hth_tht:当前时间步的隐藏状态,作为输出传递到下一个时间步。
    • tanh⁡(Ct)\tanh(C_t)tanh(Ct):对当前细胞状态进行双曲正切变换,以确保输出在 (−1,1)(-1, 1)(1,1) 范围内。
LSTM计算过程详解与数值示例:
假设参数:
• input_size = 2(每个时间步输入2个特征)
• hidden_size = 3(隐藏状态和细胞状态有3个维度)初始状态:
h_prev = [0, 0, 0]    # 上一时间步的隐藏状态
C_prev = [0, 0, 0]    # 上一时间步的细胞状态
x_t = [0.5, -0.3]     # 当前时间步的输入权重矩阵和偏置项:
# 遗忘门参数
W_f = [[0.1, 0.2, 0.3, 0.4, 0.5],[0.6, 0.7, 0.8, 0.9, 1.0],[1.1, 1.2, 1.3, 1.4, 1.5]]
b_f = [0.1, 0.2, 0.3]# 输入门参数
W_i = [[-0.1, -0.2, -0.3, -0.4, -0.5],[-0.6, -0.7, -0.8, -0.9, -1.0],[-1.1, -1.2, -1.3, -1.4, -1.5]]
b_i = [0.1, 0.2, 0.3]# 候选细胞状态参数
W_c = [[0.2, 0.3, 0.4, 0.5, 0.6],[0.7, 0.8, 0.9, 1.0, 1.1],[1.2, 1.3, 1.4, 1.5, 1.6]]
b_c = [0.1, 0.2, 0.3]# 输出门参数
W_o = [[-0.2, -0.3, -0.4, -0.5, -0.6],[-0.7, -0.8, -0.9, -1.0, -1.1],[-1.2, -1.3, -1.4, -1.5, -1.6]]
b_o = [0.1, 0.2, 0.3]1. 计算遗忘门 (Forget Gate)
目的:决定从细胞状态中丢弃哪些信息
# 拼接输入 [h_prev, x_t] = [0, 0, 0, 0.5, -0.3]
concat = [0, 0, 0, 0.5, -0.3]
# 计算遗忘门:f_t = σ(W_f · [h_prev, x_t] + b_f)
f_t = [σ(0.1*0 + 0.2*0 + 0.3*0 + 0.4*0.5 + 0.5*(-0.3) + 0.1) = σ(0.0 + 0.0 + 0.0 + 0.2 - 0.15 + 0.1) = σ(0.15)0.537,σ(0.6*0 + 0.7*0 + 0.8*0 + 0.9*0.5 + 1.0*(-0.3) + 0.2) = σ(0.0 + 0.0 + 0.0 + 0.45 - 0.3 + 0.2) = σ(0.35)0.586,σ(1.1*0 + 1.2*0 + 1.3*0 + 1.4*0.5 + 1.5*(-0.3) + 0.3) = σ(0.0 + 0.0 + 0.0 + 0.7 - 0.45 + 0.3) = σ(0.55)0.634
]
# 结果:f_t ≈ [0.537, 0.586, 0.634]2. 计算输入门 (Input Gate)
目的:决定哪些新信息存入细胞状态
# 计算输入门:i_t = σ(W_i · [h_prev, x_t] + b_i)
i_t = [σ(-0.1*0 -0.2*0 -0.3*0 -0.4*0.5 -0.5*(-0.3) + 0.1) = σ(0.0 + 0.0 + 0.0 -0.2 + 0.15 + 0.1) = σ(0.05)0.512,σ(-0.6*0 -0.7*0 -0.8*0 -0.9*0.5 -1.0*(-0.3) + 0.2) = σ(0.0 + 0.0 + 0.0 -0.45 + 0.3 + 0.2) = σ(0.05)0.512,σ(-1.1*0 -1.2*0 -1.3*0 -1.4*0.5 -1.5*(-0.3) + 0.3) = σ(0.+ 0.0 + 0.0 -0.7 + 0.45 + 0.3) = σ(0.05)0.512
]
# 结果:i_t ≈ [0.512, 0.512, 0.512]3. 计算候选细胞状态 (Candidate Cell State)
目的:生成可能添加到细胞状态的新信息
# 计算候选细胞状态:C̃_t = tanh(W_c · [h_prev, x_t] + b_c)
C̃_t = [tanh(0.2*0 + 0.3*0 + 0.4*0 +.5*0.5 + 0.6*(-0.3) + 0.1) = tanh(0.0 + 0.0 + 0.0 + 0.25 - 0.18 + 0.1) = tanh(0.17)0.168,tanh(0.7*0 + 0.8*0 + 0.9*0 + 1.0*0.5 + 1.1*(-0.3) + 0.2) = tanh(0.0 + 0.0 + 0.0 + 0.5 - 0.33 + 0.2) = tanh(0.37)0.354,tanh(1.2*0 + 1.3*0 + 1.4*0 + 1.5*0.5 + 1.6*(-0.3) + 0.3) = tanh(0.0 + 0.0 + 0.0 + 0.75 - 0.48 + 0.3) = tanh(0.57)0.515
]
# 结果:C̃_t ≈ [0.168, 0.354, 0.515]4. 更新细胞状态 (Update Cell State)目的:结合遗忘门和输入门的信息更新细胞状态
# 更新细胞状态:C_t = f_t ⊙ C_prev + i_t ⊙ C̃_t
C_t = [0.537*0 + 0.512*0.1680.086,0.586*0 + 0.512*0.3540.181,0.634*0 + 0.512*0.5150.264
]
# 结果:C_t ≈ [0.086, 0.181, 0.264]5. 计算输出门 (Output Gate)
目的:决定输出哪些信息
# 计算输出门:o_t = σ(W_o · [h_prev, x_t] + b_o)
o_t = [σ(-0.2*0 -0.3*0 -0.4*0 -0.5*0.5 -0.6*(-0.3) + 0.1) = σ(0.0 + 0.0 + 0.0 -0.25 + 0.18 + 0.1) = σ(0.03)0.508,σ(-0.7*0 -0.8*0 -0.9*0 -1.0*0.5 -1.1*(-0.3) + 0.2) = σ(0.0 + 0.0 + 0.0 -0.5 + 0.33 + 0.2) = σ(0.03)0.508,σ(-1.2*0 -1.3*0 -1.4*0 -1.5*0.5 -1.6*(-0.3) + 0.3) = σ(0.0 + 0.0 + 0.0 -0.75 + 0.48 + 0.3) = σ(0.03)0.508
]
# 结果:o_t ≈ [0.508, 0.508, 0.508]6. 更新隐藏状态 (Update Hidden State)
目的:生成当前时间步的输出
# 更新隐藏状态:h_t = o_t ⊙ tanh(C_t)
tanh_C_t = [tanh(0.086)0.086, tanh(0.181)0.179, tanh(0.264)0.258]
h_t = [0.508 * 0.0860.044,0.508 * 0.1790.091,0.508 * 0.2580.131
]
# 结果:h_t ≈ [0.044, 0.091, 0.131]最终结果
经过一个时间步的计算,我们得到:
• 新的细胞状态:C_t ≈ [0.086, 0.181, 0.264]
• 新的隐藏状态:h_t ≈ [0.044, 0.091, 0.131]
这些状态将传递到下一个时间步,继续处理序列中的下一个输入。

1.5 计算过程可视化

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle, Circle, Arrow, FancyArrowPatch
import matplotlib as mpl# 设置全局样式
mpl.rcParams['font.size'] = 12
mpl.rcParams['font.family'] = 'DejaVu Sans'# 创建画布
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')
plt.title('LSTM Computation Process - Time Step 0', fontsize=16, pad=20)# 颜色定义
input_color = '#FFD700'  # 金色 - 输入
hidden_color = '#1E90FF'  # 道奇蓝 - 隐藏状态
cell_color = '#32CD32'  # 酸橙绿 - 细胞状态
gate_color = '#FF4500'  # 橙红色 - 门控
active_color = '#FF1493'  # 深粉色 - 激活状态
arrow_color = '#8B0000'  # 深红色 - 连接线# 位置定义
x_pos = 5
y_positions = {'input': 6,'gates': 5,'hidden': 4,'cell': 3,'output_gate': 2
}# 初始状态
h_init = Circle((1, 4), 0.3, facecolor='lightgray', edgecolor='black')
c_init = Circle((1, 3), 0.3, facecolor='lightgray', edgecolor='black')
ax.add_patch(h_init)
ax.add_patch(c_init)
ax.text(1, 4, 'h_{-1}', ha='center', va='center', fontsize=10)
ax.text(1, 3, 'C_{-1}', ha='center', va='center', fontsize=10)# 创建节点 - 时间步0
# 输入节点
input_node = Circle((x_pos, y_positions['input']), 0.3, facecolor=input_color, edgecolor='black', alpha=0.7)
ax.add_patch(input_node)
ax.text(x_pos, y_positions['input'], 'x_0', ha='center', va='center', fontsize=10)# 隐藏状态节点
hidden_node = Circle((x_pos, y_positions['hidden']), 0.3, facecolor=hidden_color, edgecolor='black', alpha=0.7)
ax.add_patch(hidden_node)
ax.text(x_pos, y_positions['hidden'], 'h_0', ha='center', va='center', fontsize=10)# 细胞状态节点
cell_node = Circle((x_pos, y_positions['cell']), 0.3, facecolor=cell_color, edgecolor='black', alpha=0.7)
ax.add_patch(cell_node)
ax.text(x_pos, y_positions['cell'], 'C_0', ha='center', va='center', fontsize=10)# 门控节点
forget_gate = Circle((x_pos - 0.8, y_positions['gates']), 0.2, facecolor=gate_color, edgecolor='black', alpha=0.7)
input_gate = Circle((x_pos, y_positions['gates']), 0.2, facecolor=gate_color, edgecolor='black', alpha=0.7)
candidate_cell = Circle((x_pos + 0.8, y_positions['gates']), 0.2, facecolor=gate_color, edgecolor='black', alpha=0.7)
output_gate = Circle((x_pos, y_positions['output_gate']), 0.2, facecolor=gate_color, edgecolor='black', alpha=0.7)ax.add_patch(forget_gate)
ax.add_patch(input_gate)
ax.add_patch(candidate_cell)
ax.add_patch(output_gate)ax.text(x_pos - 0.8, y_positions['gates'], 'f_0', ha='center', va='center', fontsize=8)
ax.text(x_pos, y_positions['gates'], 'i_0', ha='center', va='center', fontsize=8)
ax.text(x_pos + 0.8, y_positions['gates'], 'C̃_0', ha='center', va='center', fontsize=8)
ax.text(x_pos, y_positions['output_gate'], 'o_0', ha='center', va='center', fontsize=8)# 创建时间步1和2的节点(不展示计算过程)
for t in range(1, 3):x_pos_t = 8 + (t - 1) * 3# 输入节点input_node_t = Circle((x_pos_t, y_positions['input']), 0.3, facecolor=input_color, edgecolor='black', alpha=0.3)ax.add_patch(input_node_t)ax.text(x_pos_t, y_positions['input'], f'x_{t}', ha='center', va='center', fontsize=10, alpha=0.3)# 隐藏状态节点hidden_node_t = Circle((x_pos_t, y_positions['hidden']), 0.3, facecolor=hidden_color, edgecolor='black', alpha=0.3)ax.add_patch(hidden_node_t)ax.text(x_pos_t, y_positions['hidden'], f'h_{t}', ha='center', va='center', fontsize=10, alpha=0.3)# 细胞状态节点cell_node_t = Circle((x_pos_t, y_positions['cell']), 0.3, facecolor=cell_color, edgecolor='black', alpha=0.3)ax.add_patch(cell_node_t)ax.text(x_pos_t, y_positions['cell'], f'C_{t}', ha='center', va='center', fontsize=10, alpha=0.3)# 时间步标签ax.text(x_pos_t, y_positions['input'] + 0.8, f'Time Step {t}', ha='center', fontsize=10, alpha=0.3)# 绘制连接线
arrows = []
arrow_labels = []# 初始到当前时间步的连接
arrow = FancyArrowPatch((1.3, 4), (x_pos - 0.3, 4), arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)
arrow = FancyArrowPatch((1.3, 3), (x_pos - 0.3, 3), arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)# 输入到门控的连接
for gate_y in [y_positions['gates']]:arrow = FancyArrowPatch((x_pos, y_positions['input'] - 0.3), (x_pos, gate_y + 0.2),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)ax.add_patch(arrow)arrows.append(arrow)# 隐藏状态到门控的连接
for gate_x in [x_pos - 0.8, x_pos, x_pos + 0.8]:arrow = FancyArrowPatch((gate_x, y_positions['hidden'] + 0.3), (gate_x, y_positions['gates'] - 0.2),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)ax.add_patch(arrow)arrows.append(arrow)# 门控到细胞状态的连接
arrow = FancyArrowPatch((x_pos - 0.8, y_positions['gates'] - 0.2), (x_pos - 0.3, y_positions['cell'] + 0.3),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)arrow = FancyArrowPatch((x_pos, y_positions['gates'] - 0.2), (x_pos, y_positions['cell'] + 0.3),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)arrow = FancyArrowPatch((x_pos + 0.8, y_positions['gates'] - 0.2), (x_pos + 0.3, y_positions['cell'] + 0.3),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)# 细胞状态到输出门的连接
arrow = FancyArrowPatch((x_pos, y_positions['cell'] - 0.3), (x_pos, y_positions['output_gate'] + 0.2),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)# 输出门到隐藏状态的连接
arrow = FancyArrowPatch((x_pos, y_positions['output_gate'] - 0.2), (x_pos, y_positions['hidden'] + 0.3),arrowstyle='->', color='gray', alpha=0.3, mutation_scale=15)
ax.add_patch(arrow)
arrows.append(arrow)# 添加公式
formula_text = ax.text(7, 1, '', fontsize=14, ha='center', bbox=dict(facecolor='white', alpha=0.8))# 添加图例
legend_elements = [Circle((0, 0), radius=0.3, facecolor=input_color, edgecolor='black', label='Input (x_t)'),Circle((0, 0), radius=0.3, facecolor=hidden_color, edgecolor='black', label='Hidden State (h_t)'),Circle((0, 0), radius=0.3, facecolor=cell_color, edgecolor='black', label='Cell State (C_t)'),Circle((0, 0), radius=0.2, facecolor=gate_color, edgecolor='black', label='Gates'),FancyArrowPatch((0, 0), (0, 0), arrowstyle='->', color=arrow_color, label='Active Connection')
]
ax.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=3)# 节点列表
nodes = {'input': input_node,'hidden': hidden_node,'cell': cell_node,'forget_gate': forget_gate,'input_gate': input_gate,'candidate_cell': candidate_cell,'output_gate': output_gate,'h_init': h_init,'c_init': c_init
}# 动画更新函数
def update(frame):# 重置所有节点颜色for node in nodes.values():if node == h_init or node == c_init:node.set_facecolor('lightgray')elif node in [input_node, hidden_node, cell_node]:if node == input_node:node.set_facecolor(input_color)elif node == hidden_node:node.set_facecolor(hidden_color)elif node == cell_node:node.set_facecolor(cell_color)else:node.set_facecolor(gate_color)node.set_alpha(0.7)# 重置所有连接线for arrow in arrows:arrow.set_alpha(0.3)arrow.set_color('gray')# 根据帧数更新if frame == 0:# 初始状态formula_text.set_text('Initialization: $h_{-1} = 0, C_{-1} = 0$')nodes['h_init'].set_facecolor(active_color)nodes['c_init'].set_facecolor(active_color)nodes['h_init'].set_alpha(1.0)nodes['c_init'].set_alpha(1.0)elif frame == 1:# 遗忘门计算formula_text.set_text('1. Forget Gate: $f_0 = \\sigma(W_f \\cdot [h_{-1}, x_0] + b_f)$')nodes['input'].set_facecolor(active_color)nodes['h_init'].set_facecolor(active_color)nodes['forget_gate'].set_facecolor(active_color)nodes['input'].set_alpha(1.0)nodes['h_init'].set_alpha(1.0)nodes['forget_gate'].set_alpha(1.0)# 激活相关连接arrows[0].set_alpha(1.0)  # h_{-1} -> h_0arrows[0].set_color(arrow_color)arrows[2].set_alpha(1.0)  # x_0 -> gatesarrows[2].set_color(arrow_color)arrows[3].set_alpha(1.0)  # h_{-1} -> f_0arrows[3].set_color(arrow_color)elif frame == 2:# 输入门计算formula_text.set_text('2. Input Gate: $i_0 = \\sigma(W_i \\cdot [h_{-1}, x_0] + b_i)$')nodes['input'].set_facecolor(active_color)nodes['h_init'].set_facecolor(active_color)nodes['input_gate'].set_facecolor(active_color)nodes['input'].set_alpha(1.0)nodes['h_init'].set_alpha(1.0)nodes['input_gate'].set_alpha(1.0)# 激活相关连接arrows[0].set_alpha(1.0)  # h_{-1} -> h_0arrows[0].set_color(arrow_color)arrows[2].set_alpha(1.0)  # x_0 -> gatesarrows[2].set_color(arrow_color)arrows[4].set_alpha(1.0)  # h_{-1} -> i_0arrows[4].set_color(arrow_color)elif frame == 3:# 候选细胞状态计算formula_text.set_text('3. Candidate Cell State: $\\tilde{C}_0 = \\tanh(W_c \\cdot [h_{-1}, x_0] + b_c)$')nodes['input'].set_facecolor(active_color)nodes['h_init'].set_facecolor(active_color)nodes['candidate_cell'].set_facecolor(active_color)nodes['input'].set_alpha(1.0)nodes['h_init'].set_alpha(1.0)nodes['candidate_cell'].set_alpha(1.0)# 激活相关连接arrows[0].set_alpha(1.0)  # h_{-1} -> h_0arrows[0].set_color(arrow_color)arrows[2].set_alpha(1.0)  # x_0 -> gatesarrows[2].set_color(arrow_color)arrows[5].set_alpha(1.0)  # h_{-1} -> C̃_0arrows[5].set_color(arrow_color)elif frame == 4:# 更新细胞状态formula_text.set_text('4. Update Cell State: $C_0 = f_0 \\odot C_{-1} + i_0 \\odot \\tilde{C}_0$')nodes['forget_gate'].set_facecolor(active_color)nodes['input_gate'].set_facecolor(active_color)nodes['candidate_cell'].set_facecolor(active_color)nodes['c_init'].set_facecolor(active_color)nodes['cell'].set_facecolor(active_color)nodes['forget_gate'].set_alpha(1.0)nodes['input_gate'].set_alpha(1.0)nodes['candidate_cell'].set_alpha(1.0)nodes['c_init'].set_alpha(1.0)nodes['cell'].set_alpha(1.0)# 激活相关连接arrows[1].set_alpha(1.0)  # C_{-1} -> C_0arrows[1].set_color(arrow_color)arrows[6].set_alpha(1.0)  # f_0 -> C_0arrows[6].set_color(arrow_color)arrows[7].set_alpha(1.0)  # i_0 -> C_0arrows[7].set_color(arrow_color)arrows[8].set_alpha(1.0)  # C̃_0 -> C_0arrows[8].set_color(arrow_color)elif frame == 5:# 输出门计算formula_text.set_text('5. Output Gate: $o_0 = \\sigma(W_o \\cdot [h_{-1}, x_0] + b_o)$')nodes['input'].set_facecolor(active_color)nodes['h_init'].set_facecolor(active_color)nodes['output_gate'].set_facecolor(active_color)nodes['input'].set_alpha(1.0)nodes['h_init'].set_alpha(1.0)nodes['output_gate'].set_alpha(1.0)# 激活相关连接arrows[0].set_alpha(1.0)  # h_{-1} -> h_0arrows[0].set_color(arrow_color)arrows[2].set_alpha(1.0)  # x_0 -> gatesarrows[2].set_color(arrow_color)elif frame == 6:# 更新隐藏状态formula_text.set_text('6. Update Hidden State: $h_0 = o_0 \\odot \\tanh(C_0)$')nodes['output_gate'].set_facecolor(active_color)nodes['cell'].set_facecolor(active_color)nodes['hidden'].set_facecolor(active_color)nodes['output_gate'].set_alpha(1.0)nodes['cell'].set_alpha(1.0)nodes['hidden'].set_alpha(1.0)# 激活相关连接arrows[9].set_alpha(1.0)  # C_0 -> o_0arrows[9].set_color(arrow_color)arrows[10].set_alpha(1.0)  # o_0 -> h_0arrows[10].set_color(arrow_color)return list(nodes.values()) + arrows + [formula_text]# 创建动画
animation = FuncAnimation(fig, update, frames=range(7),interval=1500, blit=True)plt.tight_layout()
animation.save('lstm_time_step_0.gif', writer='pillow', fps=1, dpi=100)
plt.show()

在这里插入图片描述

二、代码示例

通过两层LSTM处理一段音频频谱,打印每层的输出形状、参数形状,并可视化特征图。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import librosa
import numpy as np# 定义 LSTM 模型
class LSTMModel(nn.Module):def __init__(self, input_size):super(LSTMModel, self).__init__()self.lstm1 = nn.LSTM(input_size, 100, batch_first=True)self.lstm2 = nn.LSTM(100, 64, batch_first=True)def forward(self, x):h_out1, _ = self.lstm1(x)h_out2, _ = self.lstm2(h_out1)return h_out1, h_out2  # 返回两层的输出# 读取音频文件并处理
file_path = 'test.wav'
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True)# 选取 3 秒的数据
start_sample = int(1.5 * sample_rate)
end_sample = int(4.5 * sample_rate)
audio_segment = waveform[start_sample:end_sample]# 转换为频谱
n_fft = 512
hop_length = 256
spectrogram = librosa.stft(audio_segment, n_fft=n_fft, hop_length=hop_length)
spectrogram_db = librosa.amplitude_to_db(np.abs(spectrogram))
spectrogram_tensor = torch.tensor(spectrogram_db, dtype=torch.float32).unsqueeze(0)
spectrogram_tensor = spectrogram_tensor.permute(0, 2, 1)
print(f"Spectrogram tensor shape: {spectrogram_tensor.shape}")# 创建 LSTM 模型实例
input_size = spectrogram_tensor.shape[2]
model = LSTMModel(input_size)# 前向传播
lstm_output1, lstm_output2 = model(spectrogram_tensor)# 打印输出形状
print(f"LSTM Layer 1 output shape: {lstm_output1.shape}")
print(f"LSTM Layer 2 output shape: {lstm_output2.shape}")# 打印每层的参数形状
print(f"LSTM Layer 1 weights shape: {model.lstm1.weight_ih_l0.shape}")
print(f"LSTM Layer 1 hidden weights shape: {model.lstm1.weight_hh_l0.shape}")
print(f"LSTM Layer 1 bias shape: {model.lstm1.bias_ih_l0.shape}")print(f"LSTM Layer 2 weights shape: {model.lstm2.weight_ih_l0.shape}")
print(f"LSTM Layer 2 hidden weights shape: {model.lstm2.weight_hh_l0.shape}")
print(f"LSTM Layer 2 bias shape: {model.lstm2.bias_ih_l0.shape}")# 可视化原始频谱
plt.figure(figsize=(10, 4))
plt.imshow(spectrogram_db, aspect='auto', origin='lower', cmap='inferno')
plt.title("Original Spectrogram")
plt.xlabel("Time Frames")
plt.ylabel("Frequency Bins")
plt.colorbar(format='%+2.0f dB')
plt.tight_layout()# 可视化 LSTM 输出的特征图
plt.figure(figsize=(10, 4))# 绘制第一层 LSTM 输出的特征图
plt.subplot(2, 1, 1)
plt.imshow(lstm_output1[0].detach().numpy().T, aspect='auto', origin='lower', cmap='inferno')  # 转置
plt.title("LSTM Layer 1 Output Feature Map")
plt.xlabel("Time Steps")
plt.ylabel("Hidden State Dimensions")
plt.colorbar(label='Hidden State Value')# 绘制第二层 LSTM 输出的特征图
plt.subplot(2, 1, 2)
plt.imshow(lstm_output2[0].detach().numpy().T, aspect='auto', origin='lower', cmap='inferno')  # 转置
plt.title("LSTM Layer 2 Output Feature Map")
plt.xlabel("Time Steps")
plt.ylabel("Hidden State Dimensions")
plt.colorbar(label='Hidden State Value')plt.tight_layout()
plt.show()

在这里插入图片描述
在这里插入图片描述

Spectrogram tensor shape: torch.Size([1, 188, 257])
LSTM Layer 1 output shape: torch.Size([1, 188, 100])
LSTM Layer 2 output shape: torch.Size([1, 188, 64])
LSTM Layer 1 weights shape: torch.Size([400, 257])
LSTM Layer 1 hidden weights shape: torch.Size([400, 100])
LSTM Layer 1 bias shape: torch.Size([400])
LSTM Layer 2 weights shape: torch.Size([256, 100])
LSTM Layer 2 hidden weights shape: torch.Size([256, 64])
LSTM Layer 2 bias shape: torch.Size([256])

PyTorch LSTM 参数详解

在 LSTM 的数学表示中:

  • Wf,Wi,Wc,WoW_f, W_i, W_c, W_oWf,Wi,Wc,Wo 表示完整的权重矩阵(形状:(hidden_size, input_size + hidden_size)

在 PyTorch 实现中,这些权重被拆分为两部分:

  • 输入部分:存储在 weight_ih_l0weight\_ih\_l0weight_ih_l0
  • 隐藏部分:存储在 weight_hh_l0weight\_hh\_l0weight_hh_l0

1. weight_ih_l0

  • 含义:输入到隐藏的权重(Input to Hidden weights)
  • 形状(4×hidden_size,input_size)(4 \times hidden\_size, input\_size)(4×hidden_size,input_size)
  • 内容
    • 包含所有四个门(输入门、遗忘门、细胞状态、输出门)的输入部分权重, 按顺序堆叠:
    • 第0部分:输入门的输入权重(Wi(x)W_i^{(x)}Wi(x)
    • 第1部分:遗忘门的输入权重(Wf(x)W_f^{(x)}Wf(x)
    • 第2部分:细胞状态的输入权重(Wc(x)W_c^{(x)}Wc(x)
    • 第3部分:输出门的输入权重(Wo(x)W_o^{(x)}Wo(x)
  • 数学关系
    Wi=[Ui,Wi(x)]W_i = [U_i, W_i^{(x)}] Wi=[Ui,Wi(x)]
    其中 WiW_iWi 是数学表示中的完整权重矩阵。

2. weight_hh_l0

  • 含义:隐藏到隐藏的权重(Hidden to Hidden weights)
  • 形状(4×hidden_size,hidden_size)(4 \times hidden\_size, hidden\_size)(4×hidden_size,hidden_size)
  • 内容: 包含所有四个门的隐藏状态部分权重,按顺序堆叠:
    • 第0部分:输入门的隐藏权重(UiU_iUi
    • 第1部分:遗忘门的隐藏权重(UfU_fUf
    • 第2部分:细胞状态的隐藏权重(UcU_cUc
    • 第3部分:输出门的隐藏权重(UoU_oUo
  • 数学关系
    Wi=[Ui,Wi(x)]W_i = [U_i, W_i^{(x)}] Wi=[Ui,Wi(x)]
    其中 WiW_iWi 是数学表示中的完整权重矩阵。

3. bias_ih_l0

  • 含义:输入到隐藏的偏置(Input to Hidden bias)
  • 形状(4×hidden_size)(4 \times hidden\_size)(4×hidden_size)
  • 内容:包含所有四个门的偏置项,按顺序堆叠:
    • 第0部分:输入门偏置(bib_ibi
    • 第1部分:遗忘门偏置(bfb_fbf
    • 第2部分:细胞状态偏置(bcb_cbc
    • 第3部分:输出门偏置(bob_obo

4. bias_hh_l0

  • 含义:隐藏到隐藏的偏置(Hidden to Hidden bias)
  • 形状(4×hidden_size)(4 \times hidden\_size)(4×hidden_size)
  • 内容:包含所有四个门的隐藏状态偏置,按顺序堆叠:
    • 第0部分:输入门隐藏偏置(bi(h)b_i^{(h)}bi(h)
    • 第1部分:遗忘门隐藏偏置(bf(h)b_f^{(h)}bf(h)
    • 第2部分:细胞状态隐藏偏置(bc(h)b_c^{(h)}bc(h)
    • 第3部分:输出门隐藏偏置(bo(h)b_o^{(h)}bo(h)
  • 实际使用
    • 在计算中,总偏置为 b=bias_ih_l0+bias_hh_l0b = bias\_ih\_l0 + bias\_hh\_l0b=bias_ih_l0+bias_hh_l0
    • bi=bi+bi(h)b_i = b_i + b_i^{(h)}bi=bi+bi(h)(逐元素相加)

文章转载自:

http://Z6Y8xspd.cybch.cn
http://IDWpjhwJ.cybch.cn
http://7Ux3agSZ.cybch.cn
http://LA9otfn1.cybch.cn
http://BTaNIFAl.cybch.cn
http://PIBzqdPc.cybch.cn
http://DC4RnF7F.cybch.cn
http://q1AVby9h.cybch.cn
http://Bozw6R3K.cybch.cn
http://3aeQBjzO.cybch.cn
http://4Ku45QWY.cybch.cn
http://FjT5VMrb.cybch.cn
http://MpYZt7Am.cybch.cn
http://YojAbyys.cybch.cn
http://PtTSSizH.cybch.cn
http://UVLZpk8X.cybch.cn
http://vk5qA7Fz.cybch.cn
http://QlANl4Do.cybch.cn
http://lleTkVeX.cybch.cn
http://19KQo1p0.cybch.cn
http://LhJUDfW5.cybch.cn
http://eHvnOAPu.cybch.cn
http://JLBmEGQV.cybch.cn
http://wYjYriKL.cybch.cn
http://vJEIy2cX.cybch.cn
http://IGgB0cFl.cybch.cn
http://udPuVX13.cybch.cn
http://j3PmgdzT.cybch.cn
http://xbPX8EJv.cybch.cn
http://fPExm9DQ.cybch.cn
http://www.dtcms.com/a/386877.html

相关文章:

  • 初始化Vue3 项目
  • 耕地质量评价
  • MeloTTS安装实践
  • 国产化芯片ZCC3790--同步升降压控制器的全新选择, 替代LT3790
  • LeetCode 977.有序数组的平方
  • 佳易王个体诊所中西医电子处方管理系统软件教程详解:开方的时候可一键导入配方模板,自由添加模板
  • C#实现WGS-84到西安80坐标系转换的完整指南
  • rabbitmq面试题总结
  • 【Java初学基础】⭐Object()顶级父类与它的重要方法equals()
  • C语言初尝试——洛谷
  • Kaleidoscope for Mac:Mac 平台文件与图像差异对比的终极工具
  • LeetCode 刷题【80. 删除有序数组中的重复项 II】
  • 淘宝扭蛋机小程序系统开发:引领电商娱乐化潮流
  • 【车载audio开发】【基础概念2】【Usage、ContentType、Flags、SessionId之间的关系】
  • 【Day 52 】Linux-Jenkins
  • 向内核社区提交补丁
  • 【Java-常用类】
  • 在线教程丨ACL机器翻译大赛30个语种摘冠,腾讯Hunyuan-MT-7B支持33种语言翻译
  • 006 Rust基本数据类型
  • docker配置代理加速
  • 基于MATLAB的视频动态目标跟踪检测实现方案
  • AirPods Pro 3正式发布:全方位升级​
  • PyTorch生成式人工智能(29)——基于Transformer生成音乐
  • 《如龙8外传》共五章:漂流记、老人与海、金银岛等!
  • NVIDIA DOCA与BlueField DPU理解与学习
  • 蜜罐--攻防、护网的强大助力
  • OpenStack 学习笔记
  • 2025年09月16日Github流行趋势
  • git永久存储凭证(可以不用经常输入git密钥)
  • 豆包对高可用系统建设的理解