当前位置: 首页 > news >正文

第二章、LSTM(Long Short-term Memory:长短时记忆网络)

0 前言

RNN(循环神经网络)本身存在各种各样的缺陷,比如梯度弥散、梯度爆炸和短时记忆的问题。为弥补RNN的这些问题,瑞士人工智能科学家于1997提出了Long Short-term Memory(长短时记忆网络),即现在常用的LSTM。

1 RNN的局限性

以下阐述流程

  • 问题出现的原因
  • 直观的解决问题的方法

循环神经网络会出现这三个问题的绝大多数原因取决于其参数梯度中的δhtδhi\frac{\delta h_t}{\delta h_i}δhiδht这一项。其展开如下所示,此处不做推导:
δhtδhi=Πj=it−1diag(σ′(Wxhxj+1+Whhhj+b))Whh\frac{\delta h_t}{\delta h_i}=\Pi^{t-1}_{j=i}diag(\sigma'(W_{xh}x_{j+1}+W_{hh}h_j+b))W_{hh}δhiδht=Πj=it1diag(σ(Wxhxj+1+Whhhj+b))Whh

观察上式我们发现实际上这个式子中存在WhhW_{hh}Whh的连乘运算,那么如果矩阵WhhW_{hh}Whh的最大特征值小于1,连乘会导致δhtδhi\frac{\delta h_t}{\delta h_i}δhiδht趋近于0,这就导致了梯度弥散。相对应的如果该值大于1,则会导致δhtδhi\frac{\delta h_t}{\delta h_i}δhiδht值爆炸式增长,即梯度爆炸。

1.1 梯度爆炸

很自然的,因为某个值过大而产生的问题,我们可以通过限制该值来解决。我们可以做梯度裁减,使WWW中的所有元素都在一定范围内就可以了。

  • 假设张量为WWW,令所有元素wij∈[min,max]w_{ij}\in[min,max]wij[min,max]
  • 假设张量为WWW,限制张量的二范数∣∣W∣∣2∈[0,max]||W||_2\in[0,max]∣∣W2[0,max],若∣∣W∣∣2>max||W||_2>max∣∣W2>max,则令W′=W∣∣W∣∣2⋅maxW'=\frac{W}{||W||_2}\cdot maxW=∣∣W2Wmax
  • 假设张量为WWW,考虑全局范数裁减,令global_norm=∑i∣∣W(i)∣∣22global\_norm=\sqrt{\sum_i{||W^{(i)}||_2 }^2}global_norm=i∣∣W(i)22,则有W(i)=W(i)⋅max_normmax(global_norm,max_norm)W^{(i)}=\frac{W^{(i)}\cdot max\_norm}{max(global\_norm,max\_norm)}W(i)=max(global_norm,max_norm)W(i)max_norm

上面的三种方法实际上只是从不同角度出发的裁减,目的都是一样的防止WWW过大导致梯度爆炸。

1.2 梯度弥散

对于梯度弥散现象,可以通过增加学习率、减少网络深度、添加SKip Connection(跳接,不了解可以看看unet)等一系列措施抑制。

1.3 短时记忆

上述两个问题必然会导致RNN的短时记忆,那么接下来就是来看LSTM是怎么解决这些问题的,我们先介绍门控制,再对门控制进行组合成为LSTM。

2 门控机制

实际上门这个概念很好理解,不管是电路、生物还是电脑的最底层理论里无外乎都是这些东西,那什么是门,通俗的理解就是有的东西能过去有的东西过不去,它对信号也好,化学物质也好做了筛选,实际上LSTM中的门控也是这样的。
在这里插入图片描述

LSTM的门控机制如上图所示,这个图实际上就表明了输出o=输入x∗门控值g输出o=输入x*门控值g输出o=输入x门控值g,门控制g∈(0,1)g\in (0,1)g0,1,显然g=0g=0g=0表示门关闭输入完全没有进来,g=1g=1g=1时刚好相反。

这个理念很好理解,但这里存在一个问题,我们的大脑对自动根据环境信息判断当前的信息要不要接收或者接受多少,这个东西就是所谓的门控值ggg,那么在LSTM中这个门控值ggg怎么计算呢?

实际上也很简单,我们也根据现在输入的环境信息获取一个取值范围在0到1之间的值就可以了。

LSTM有两个很重要的变量一个是输出hth_tht,一个是状态ctc_tct

2.1 遗忘门

在这里插入图片描述
如上图所示实际上遗忘门就是对过去的状态ct−1c_{t-1}ct1做筛选,而该门的门控值是通过ht−1、xth_{t-1}、x_tht1xt得到的,而门控值的取值范围是(0,1)(0,1)(0,1),因此最合理的方式是采取sigmoidsigmoidsigmoid函数,即ft=sigmoid(Wf[ht−1,xt]+bf)f_t=sigmoid(W_f[h_{t-1},x_t]+b_f)ft=sigmoid(Wf[ht1,xt]+bf),经过该遗忘门后状态向量ct−1c_{t-1}ct1变为ft∗ct−1f_t*c_{t-1}ftct1

2.2 输入门

在这里插入图片描述
上图中的蓝色虚线部分就是输入门的部分,输入门的门控值依然是通过ht−1、xth_{t-1}、x_tht1xt得到的,即it=sigmoid(Wi[ht−1,xt]+bi)i_t=sigmoid(W_i[h_{t-1},x_t]+b_i)it=sigmoid(Wi[ht1,xt]+bi),而输入门要过滤的值同样与输入相关,ct~=tanh(Wc[ht−1,xt]+bc)\tilde{c_t}=tanh(W_c[h_{t-1},x_t]+b_c)ct~=tanh(Wc[ht1,xt]+bc),该值经过输入门后变为it∗ct~i_t*\tilde{c_t}itct~

将输入门的结果与遗忘门的结果相加得到的就是新的状态向量ct=ft∗ct−1+it∗ct~c_t=f_t*c_{t-1}+i_t*\tilde{c_t}ct=ftct1+itct~

2.3 输出门

在这里插入图片描述
输出门的门控值依然是通过ht−1、xth_{t-1}、x_tht1xt得到的,即ot=sigmoid(Wo[ht−1,xt]+bo)o_t=sigmoid(W_o[h_{t-1},x_t]+b_o)ot=sigmoid(Wo[ht1,xt]+bo),而输出门要过滤的值是tanh(ct)tanh(c_t)tanh(ct),所以输出ht=ot∗tanh(ct)h_t=o_t*tanh(c_t)ht=ottanh(ct)

2.4 LSTM解决梯度爆炸及梯度弥散的方法

实际上我们通过简单的推理就能知道:ctc0≈Πj=1tfi\frac{c_t}{c_0}\approx \Pi^t_{j=1}f_ic0ctΠj=1tfi,其中fif_ifi是门控制,它的取值范围在(0,1)(0,1)(0,1)之间,实际上来说这也是一种裁减方式。fk​<1f_k​<1fk<1的约束避免了梯度爆炸。

深究RNN 我们会发现实际上导致梯度弥散的本质是激活函数求导造成的。

正向传播:

ht=σ(W⋅[ht−1,xt]+b)h_t=σ(W⋅[h_{t−1},x_t]+b)ht=σ(W[ht1,xt]+b)

σσσ 是激活函数(如 tanhtanhtanhsigmoidsigmoidsigmoid

反向传播(关键路径):

损失函数 LLLht−kh_{t−k}htk​ 的梯度依赖于链式法则:

∂L∂ht−k=∂L∂ht⋅(Πj=t−k+1t∂hj∂hj−1)\frac{\partial L}{\partial h_{t-k}}=\frac{\partial L}{\partial h_t}\cdot (\Pi^t_{j=t-k+1}\frac{\partial h_j}{\partial h_{j-1}})htkL=htL(Πj=tk+1thj1hj)

Πj=t−k+1t∂hj∂hj−1\Pi^t_{j=t-k+1}\frac{\partial h_j}{\partial h_{j-1}}Πj=tk+1thj1hj部分的连乘是导致梯度弥散的关键。

其中∂hj∂hj−1=diag(σ′(zj))⋅Whh\frac{\partial h_j}{\partial h_{j-1}}=diag(\sigma'(z_j))\cdot W_{hh}hj1hj=diag(σ(zj))Whh

其中σ′(zj)\sigma'(z_j)σ(zj)是激活函数的导数,其值远小于1,所以就算WhhW_hhWhh特征值接近于1,连乘还是会导致梯度弥散。

但是LSTM中实际上梯度ctc0≈Πj=1tfi\frac{c_t}{c_0}\approx \Pi^t_{j=1}f_ic0ctΠj=1tfi只与门控值相关,没有激活函数的导数,从而及大程度的避免了梯度弥散的出现。

http://www.dtcms.com/a/324168.html

相关文章:

  • 【CF】Day124——杂题 (鸽巢原理 | 构造 | 贪心 + 模拟)
  • Excel常用功能函数
  • vue3-基础语法
  • 开启单片机
  • jvm有哪些垃圾回收器,实际中如何选择?
  • 【FPGA】初识FPGA
  • Git 版本管理规范与最佳实践摘要
  • 后量子密码学的迁移与安全保障:迎接量子时代的挑战
  • 【鸿蒙/OpenHarmony/NDK】C/C++开发教程之环境搭建
  • Linux操作系统从入门到实战(十八)在Linux里面怎么查看进程
  • HarmonyOS NEXT系列之编译三方C/C++库
  • 人工智能-python-机器学习-决策树与集成学习:决策树分类与随机森林
  • 给AI装上“翻译聚光灯”:注意力机制的机器翻译革命
  • ECharts Y轴5等分终极解决方案 - 动态适配缩放场景
  • 【走进Docker的世界】Docker的发展历程
  • MyBatis-Plus 逻辑删除
  • Spark学习(Pyspark)
  • Shell脚本-了解i++和++i
  • wordpress语言包制作工具
  • 点击速度测试:一款放大操作差距的互动挑战游戏
  • 简要介绍交叉编译工具arm-none-eabi、arm-linux-gnueabi与arm-linux-gnueabihf
  • 面向高级负载的 Kubernetes 调度框架对比分析:Volcano、YuniKorn、Kueue 与 Koordinator
  • Z20K118库中寄存器及其库函数封装-PMU库
  • ThreadLocal有哪些内存泄露问题,如何避免?
  • 机器学习实战·第三章 分类(1)
  • SAP HCM 结构化授权函数
  • 计算机网络:路由聚合是手动还是自动完成的?
  • 采用GPT5自动规划实现番茄计时器,极简提示词,效果达到产品级
  • 算术运算符指南
  • 震动马达实现库函数版(STC8)