当前位置: 首页 > news >正文

07_GRU模型

GRU模型

双向GRU笔记:https://blog.csdn.net/weixin_44579176/article/details/146459952

概念

  • GRU(Gated Recurrent Unit)也称为门控循环单元,是一种改进版的RNN。与LSTM一样能够有效捕捉长序列之间的语义关联,通过引入两个"门"机制(重置门和更新门)来控制信息的流动,从而避免了传统RNN中的梯度消失问题,并减少了LSTM模型中的复杂性。

    [^ 要点]:1.GRU同样是通过门机制来解决传统RNN中的梯度消失问题的 2.GRU相比于LSTM更为简洁,它只引入了两个门 :更新门(Update Gate), 重置门(Reset Gate)

核心组件

  1. 重置门(Reset Gate)

    • 作用: 决定如何将新的输入与之前的隐藏状态结合。

      • 当重置门值接近0时,表示当前时刻的输入几乎不依赖上一时刻的隐藏状态。
      • 当重置门值接近1时,表示当前时刻的输入几乎完全依赖上一时刻的隐藏状态。
    • 公式(变体版本): r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = σ(W_r·[h_{t-1},x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

      • r t r_t rt| 重置门值, r t ∈ ( 0 , 1 ) r_t ∈ (0,1) rt(0,1)
      • W r W_r Wr 和$ b_r$ | 重置门权值和偏置项
      • σ | sigmoid函数 保证 r t r_t rt的输出值在 0 到 1之间
  2. 更新门(Update Gate)

    • 作用: 决定多少之前的信息需要保留,多少新的信息需要更新。

      • 当更新门值接近0时,意味着网络只记住旧的隐藏状态,几乎没有新的信息。
      • 当更新门值接近1时,意味着网络更倾向于使用新的隐藏状态,记住当前输入的信息。
    • 公式(变体版本): z t = σ ( W r ⋅ [ h t − 1 , x t ] + b z ) z_t = σ(W_r·[h_{t-1},x_t] + b_z) zt=σ(Wr[ht1,xt]+bz)

      • z t z_t zt| 更新门值, z t ∈ ( 0 , 1 ) z_t ∈ (0,1) zt(0,1)
      • W r W_r Wr 和$ b_r$ | 重置门权值和偏置项
      • σ | sigmoid函数 保证 z t z_t zt的输出值在 0 到 1之间
  3. 候选隐藏状态(Candidate Hidden State)

    • 作用: 捕捉当前时间步的信息,多少前一隐藏状态的信息被保留。

    • 公式(变体版本): h ^ t = t a n h ( W h ⋅ [ r t ⊙ h t − 1 , x t ] + b h ) ĥ_t = tanh(W_h · [r_t \odot h_{t-1} , x_t] + b_h) h^t=tanh(Wh[rtht1,xt]+bh)

      • h ^ t ĥ_t h^t| 候选隐藏状态值, h ^ t ∈ ( − 1 , 1 ) ĥ_t ∈ (-1,1) h^t(1,1)
      • W h W_h Wh 和$ b_h$ | 候选隐藏状态的权重和偏置项
      • tanh| 双曲正切函数 保证 h t h_t ht的输出值在 -1 到 1之间
      • ⊙ \odot | Hadamard Product
  4. 最终隐藏状态(Final Hidden State)

    • 作用: 控制信息更新,传递长期依赖。

    • 公式(变体版本): h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ^ t h_t = (1-z_t) \odot h_{t-1} + z_t \odot ĥ_t ht=(1zt)ht1+zth^t

      • h t h_t ht| 当前时间步的隐藏状态
      • z t z_t zt | 更新门的输出,控制新旧信息的比例
      • ⊙ \odot | Hadamard Product

    重置门与更新的对比

    门控机制核心功能直观理解
    重置门(Reset Gate)控制历史信息对当前候选状态的影响:决定是否忽略部分或全部历史信息,从而生成新的候选隐藏状态。“是否忘记过去,重新开始?”(例如:处理句子中的突变或新段落)
    更新门(Update Gate)控制新旧信息的融合比例:决定保留多少旧状态的信息,同时引入多少候选状态的新信息。“保留多少旧记忆,吸收多少新知识?”(例如:维持长期依赖关系)

    重置门作用举例:

    ​ input: [‘风’,‘可以’,‘吹起’,‘一大张’,‘白纸’,‘’,‘无法’,‘吹走’,‘一只’,‘蝴蝶’,‘因为’,‘生命’,‘的’,‘力量’,‘在于’,‘不’,‘顺从’]

    • 当处理到 ‘却’ 时,上文信息 : 风可以吹起一大张白纸

      • 重置门值 : r t = 0.3 r_t = 0.3 rt=0.3
        • 作用:忽略部分历史信息,弱化上文影响,为后续信息(无法吹走一只蝴蝶)腾出空间
      • 更新门值 : z t = 0.8 z_t = 0.8 zt=0.8
        • 作用: 表示保留更多候选隐藏状态(由于 r t r_t rt是一个较小的值,所以候选隐藏状态中新信息占比更大) 的信息

      [^ 注]: 此时$ h_t $接近 $ ĥ_t$,隐藏状态被重置为“准备处理转折后的新逻辑”。

    • 当处理到 ‘因为’ 时,上文信息 : 少部分的 "风可以吹起一大张白纸 " + 大部分的 “无法吹走蝴蝶”

      • 重置门值 : r t = 0.8 r_t = 0.8 rt=0.8
        • 作用:保留更多上文信息,以便与后续原因关联
      • 更新门值 : z t = 0.5 z_t = 0.5 zt=0.5
        • 作用: 平衡旧状态(上文结论) 和 新状态(下文原因) ,逐步构建完整的逻辑链

内部结构

在这里插入图片描述
在这里插入图片描述

  • GRU的更新门和重置门结构图

在这里插入图片描述

Pytorch实现

nn.GRU(input_size, hidden_size, num_layers, bidirectional, batch_first, dropout)

[^ input_size ]:输入特征的维度
[^ hidden_size ]:隐藏状态的维度
[^ num_layers ]:GRU的层数(默认值为1)
[^ batch_first ]:如果为True,输入和输出的形状为 (batch_size, seq_len, input_size);否则为 (seq_len, batch_size, input_size)
[^ bidirectional ]:如果为True,使用双向GRU;否则为单向GRU(默认False)
[^ dropout ]:在多层GRU中,是否在层之间应用dropout(默认值为0)
使用示例
# 定义GRU的参数含义: (input_size, hidden_size, num_layers)
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量的参数含义: (num_layers * num_directions, batch_size, hidden_size)
import torch.nn as nn
import torch

def dm_gru():
    # 创建GRU层
    gru = nn.GRU(input_size=5, hidden_size=6, num_layers=2)
    # 创建输入张量
    input = torch.randn(size=(1, 3, 5))
    # 初始化隐藏状态
    h0 = torch.randn(size=(2, 3, 6))
    # hn输出两层隐藏状态, 最后1个隐藏状态值等于output输出值
    output, hn = gru(input, h0)
    print('output--->', output.shape, output)
    print('hn--->', hn.shape, hn)

相关文章:

  • ChatGPT vs DeepSeek vs Copilot vs Claude:谁将问鼎AI王座?
  • HTML 表单处理进阶:验证与提交机制的学习心得与进度(一)
  • 优选算法的睿智之林:前缀和专题(一)
  • Codeforces Round 1012 (Div. 2)(ABCD)
  • 【Vue3入门2】02-记事本案例
  • redis命令
  • 并查集(竞赛)
  • 生活电子类常识——搭建openMauns工作流+搭建易犯错解析
  • STM32单片机uCOS-Ⅲ系统10 内存管理
  • visual studio code 开发STM32步骤
  • 使用Python开发智能家居系统:基于语音命令的设备控制
  • 常⻅中间件漏洞--Tomcat
  • 深度解析 BPaaS:架构、原则与研发模式探索
  • 《Operating System Concepts》阅读笔记:p471-p472
  • Python常用库全解析:从数据处理到机器学习
  • leetcode0560. 和为 K 的子数组-medium
  • fatal: Unable to create /.git/index.lock‘: File exists.
  • WRC世界机器人大会-2024年展商汇总
  • 可发1区的创新思路:​基于K-means聚类的EMD-BiLSTM-CNN-Attention时间序列预测模型(功率预测、寿命预测、流量预测、故障诊断)
  • 链表相关知识总结
  • 证监会:坚决拥护党中央对王建军进行纪律审查和监察调查决定
  • 夜读丨春天要去动物园
  • “五一”逃离城市计划:带上帐篷去大自然里充电
  • 上海“模速空间”:将形成人工智能“北斗七星”和群星态势
  • 夜读丨跷脚牛肉乐翘脚
  • 豆神教育:2024年净利润1.37亿元,同比增长334%