当前位置: 首页 > news >正文

58、深度学习-自学之路-自己搭建深度学习框架-19、RNN神经网络梯度消失和爆炸的原因(从公式推导方向来说明),通过RNN的前向传播和反向传播公式来理解。

一、RNN神经网络的前向传播图如下:

时间步 t=1:
x₁ → (W_x) → [RNN Cell] → h₁ → (W_y) → y₁
           ↑ (W_h)
          h₀ (初始隐藏状态)

时间步 t=2:
x₂ → (W_x) → [RNN Cell] → h₂ → (W_y) → y₂
           ↑ (W_h)
          h₁

时间步 t=3:
x₃ → (W_x) → [RNN Cell] → h₃ → (W_y) → y₃
           ↑ (W_h)
          h₂

过程解释

时间步 t=1:

  1. 输入

    • 输入 x₁ 是第一个时间步的输入数据(例如,一个词向量或时间序列数据点)。

  2. 权重作用

    • 输入 x₁ 通过权重矩阵 W_x 进行线性变换:

                W_x · x₁

               2. 初始隐藏状态 h₀ 通过权重矩阵 W_h 进行线性变换:

                W_h · h₀        

   3.RNN Cell 计算

  • 将变换后的输入和隐藏状态相加,并加上偏置项 b_h,然后通过激活函数 σ(如 tanh 或 ReLU):

        h₁ = σ(W_h · h₀ + W_x · x₁ + b_h)    

  h₁ 是第一个时间步的隐藏状态,包含了当前输入 x₁ 和前一个隐藏状态 h₀ 的信息。

4.输出计算

  • 隐藏状态 h₁ 通过权重矩阵 W_y 进行线性变换,并加上偏置项 b_y,然后通过激活函数 σ

        y₁ = σ(W_y · h₁ + b_y)

  • y₁ 是第一个时间步的输出(例如,预测的下一个词或时间序列值)。

时间步 t=2:

x₂ → (W_x) → [RNN Cell] → h₂ → (W_y) → y₂
           ↑ (W_h)
          h₁

过程解释

  1. 输入

    • 输入 x₂ 是第二个时间步的输入数据。

  2. 权重作用

    • 输入 x₂ 通过权重矩阵 W_x 进行线性变换:

      W_x · x₂

    • 前一个隐藏状态 h₁ 通过权重矩阵 W_h 进行线性变换:

      W_h · h₁

  3. RNN Cell 计算

    • 将变换后的输入和隐藏状态相加,并加上偏置项 b_h,然后通过激活函数 σ

      h₂ = σ(W_h · h₁ + W_x · x₂ + b_h)

    • h₂ 是第二个时间步的隐藏状态,包含了当前输入 x₂ 和前一个隐藏状态 h₁ 的信息。

  4. 输出计算

    • 隐藏状态 h₂ 通过权重矩阵 W_y 进行线性变换,并加上偏置项 b_y,然后通过激活函数 σ

      y₂ = σ(W_y · h₂ + b_y)

    • y₂ 是第二个时间步的输出。

时间步 t=3

复制

x₃ → (W_x) → [RNN Cell] → h₃ → (W_y) → y₃
           ↑ (W_h)
          h₂
过程解释
  1. 输入

    • 输入 x₃ 是第三个时间步的输入数据。

  2. 权重作用

    • 输入 x₃ 通过权重矩阵 W_x 进行线性变换:

      W_x · x₃

    • 前一个隐藏状态 h₂ 通过权重矩阵 W_h 进行线性变换:

      W_h · h₂

  3. RNN Cell 计算

    • 将变换后的输入和隐藏状态相加,并加上偏置项 b_h,然后通过激活函数 σ

      h₃ = σ(W_h · h₂ + W_x · x₃ + b_h)

    • h₃ 是第三个时间步的隐藏状态,包含了当前输入 x₃ 和前一个隐藏状态 h₂ 的信息。

  4. 输出计算

    • 隐藏状态 h₃ 通过权重矩阵 W_y 进行线性变换,并加上偏置项 b_y,然后通过激活函数 σ

      y₃ = σ(W_y · h₃ + b_y)

    • y₃ 是第三个时间步的输出。

通过上面的公式的观察,大家可以看到一个问题就是:

一共有3个时间步,也就是信息向前传播了三次,然后每次传播使用的输入的权重层是同一个权重、隐藏层对应的权重层是是同一个权重值。

第一个神经网络输出的隐藏层h1给到第二个神经网络,

第二个神经网络的输出隐藏层h2给到第三个神经网络。

第三个神经网络的输出隐藏层h3应该会给到第四个神经网络,如果有的话。

然后每一个神经网络都会有一个预测值y1 y2 y3

如果我们的输出方式是多输入多输出,那么我们每一个预测值y1 y2 y3都会对应一个真实值ture1、ture2、ture3

然后对应着三个误差值loss1 loss2 loss3,然后把loss1 + loss2 + loss3 =L

这个就是前向传播的过程。

二、RNN神经网络的反向传播:

反向传播我们从T=3开始往后传播:

首先:y₃ = σ(W_y · h₃ + b_y)

瞬时函数为L

(1)计算输出层的梯度

损失函数对输出 y_t 的梯度:这个里面的t 你可以认为是3,方便你理解。

        

损失函数对隐藏状态 h_t 的梯度:h_3方便你理解。

然后我们考虑一下,从L 到 h_3所经过的路线:

 L  --->  y3  --->h_3

上式子中:

:t=3

              :t=3

                   :T=3,前面传递的三个预测值和真实值之间的误差值之和

从L到h3 经过了L  --->  y3  --->h_3这个路径。

所以L 对h_3的求导为:

其实我们求导到最后应该要要对权重的求导,因为最后要通过修改权重来学习内容。

也就是:

第一个是误差对输入层权重的导数

第二个是误差对隐藏层权重的导数

第三个是误差对预测层权重的导数

然后我们先看一下误差对输入层权重的求导,

现在我们考虑一下L到Wx的路径有哪些,

第一条是:L --->y3 --->h3 --->Wx

第二条是:L --->y3 --->h3 --->h2 --->Wx

第三条是:L --->y3 --->h3 --->h2 --->h1 --->Wx

然后把第一条路径的导数加上第二条路径的导数再加上第三条路径的导数就是L对Wx的求导。

这个是我自己推导的,可能有些地方不够严谨,但是具体的过程是正确的,从最后的公式我们可以看到距离Y3 最远的X1的前面的值是激活函数的导数的三次方乘以隐藏层权重Wh的平方,那么如果权重的值远小于1,平方后再乘以激活函数的3次方肯定已经远小于1,非常接近0了,那么最后由x1能给L对Wx的导数的值提供的影响就大大减小了,那么x1对L的影响就大大减小了,那么就导致了梯度的消失。对x1信息的遗忘。所有RNN不能够处理很长文本的原因。

梯度爆炸是因为,如果Wh很大比如说是100,平方就是10000,激活函数是0.2,最后的结果就是

80,那么就是说x1对整体的影响可以达到80这么多。如果再传递一层,就更大了。这就导致了梯度爆炸,误差难以收敛。

还有L对Wh的导数

L对Wy的导数都是同样推导的。

不知道大家能不能理解。可以自己动手推导一下。然后就好理解了。

http://www.dtcms.com/a/45462.html

相关文章:

  • 商城源码的框架
  • JAVA学习笔记038——bean的概念和常见注解标注
  • 计算机毕业设计SpringBoot+Vue.js体育馆使用预约平台(源码+文档+PPT+讲解)
  • Pytest之fixture的常见用法
  • AI人工智能机器学习之监督线性模型
  • 【广度优先搜索】图像渲染 岛屿数量
  • 7-1JVMCG垃圾回收
  • 【文献阅读】A Survey Of Resource-Efficient LLM And Multimodal Foundation Models
  • 如何保证 Redis 缓存和数据库的一致性?
  • 在编译Linux的内核镜像和模块时,必须先编译内核镜像,再编译模块,顺序不可随意调整的原因
  • 备战蓝桥杯Day11 DFS
  • React 常见面试题及答案
  • Mysql系统表
  • 【考试大纲】中级信息安全工程师考试大纲
  • HTMLS基本结构及标签
  • 神经网络之CNN图像识别(torch api 调用)
  • 建易WordPress
  • 算法-二叉树篇23-二叉搜索树中的插入操作
  • 夜天之书 #106 Apache 软件基金会如何投票选举?
  • Java 大视界 -- Java 大数据在智能安防入侵检测与行为分析中的应用(108)
  • AF3 DataPipeline类process_core 方法解读
  • sql server 版本更新日期
  • 经典算法 金币阵列问题
  • 【SpringCloud】黑马微服务学习笔记
  • 考虑复杂遭遇场景下的COLREG,基于模型预测人工势场的船舶运动规划方法附Matlab代码
  • 【Nginx】在Windows服务器上用Nginx部署Vue前端全流程(附避坑指南)
  • 全监督、半监督、弱监督、无监督
  • 【Python篇】PyQt5 超详细教程——由入门到精通(序篇)
  • PDF文件转换为PNG图像
  • Kubernetes kubelet inotify