当前位置: 首页 > news >正文

传统RNN模型

一.前言

上一章节介绍了一下RNN模型,而本章节主要是要了解传统RNN的内部结构及计算公式,掌握Pytorch中传统RNN⼯具的使⽤.以及了解传统RNN的优势与缺点.

二.RNN结构分析

 

结构解释图:

 

内部结构分析:

我们把⽬光集中在中间的⽅块部分, 它的输⼊有两部分, 分别是h(t-1)以及x(t), 代表上⼀时间步的隐 层输出, 以及此时间步的输⼊, 它们进⼊RNN结构体后, 会"融合"到⼀起, 这种融合我们根据结构解 释可知, 是将⼆者进⾏拼接, 形成新的张量[x(t), h(t-1)], 之后这个新的张量将通过⼀个全连接层(线 性层), 该层使⽤tanh作为激活函数, 最终得到该时间步的输出h(t), 它将作为下⼀个时间步的输⼊ 和x(t+1)⼀起进⼊结构体. 以此类推. 

内部结构过程演示:

 

根据结构分析得出内部计算公式:

 

tips:W_t [x_t, h_{t-1}] 就是代码中的nn.Linear(N, M), 相当于神经网络中的一个全连接层!!!

代码nn.Linear(N, M) 相当于矩阵W的维度是 [N, M]  

激活函数tanh的作⽤:

⽤于帮助调节流经⽹络的值, tanh函数将值压缩在-1和1之间. 

三.使用Pytorch构建RNN模型 

位置: 在torch.nn⼯具包之中, 通过torch.nn.RNN可调⽤

nn.RNN类初始化主要参数解释:

input_size: 输⼊张量x中特征维度的⼤⼩

hidden_size: 隐层张量h中特征维度的⼤⼩

num_layers: 隐含层的数量

nonlinearity: 激活函数的选择, 默认是tanh 

nn.RNN类实例化对象主要参数解释:

input: 输⼊张量x

h0: 初始化的隐层张量h 

 nn.RNN使⽤示例:

import torch
import torch.nn as nn# 输入维度5,隐藏层维度6,1层
rnn = nn.RNN(5, 6, 1)# 输入数据:形状 (seq_len=1, batch_size=3, input_size=5)
input = torch.randn(1, 3, 5)#(num_layers=1, batch_size=3, hidden_size=6)
h0 = torch.randn(1, 3, 6)
output, hn = rnn(input, h0)print(output)print(hn)

 

当然这些是要相等的,大家可以理解一下 

四. 传统RNN优缺点

传统RNN的优势 :

由于内部结构简单, 对计算资源要求低, 相⽐之后我们要学习的RNN变体:LSTM和GRU模型参数总量少 了很多, 在短序列任务上性能和效果都表现优异. 

传统RNN的缺点: 

传统RNN在解决⻓序列之间的关联时, 通过实践,证明经典RNN表现很差, 原因是在进⾏反向传播的时 候, 过⻓的序列导致梯度的计算异常, 发⽣梯度消失或爆炸.

梯度消失或爆炸介绍:

根据反向传播算法和链式法则, 梯度的计算可以简化为以下公式

其中sigmoid的导数值域是固定的, 在[0, 0.25]之间, ⽽⼀旦公式中的w也⼩于1, 那么通过这样的公式连 乘后, 最终的梯度就会变得⾮常⾮常⼩, 这种现象称作梯度消失. 反之, 如果我们⼈为的增⼤w的值, 使其 ⼤于1, 那么连乘够就可能造成梯度过⼤, 称作梯度爆炸.

梯度消失或爆炸的危害:

如果在训练过程中发⽣了梯度消失,权重⽆法被更新,最终导致训练失败; 梯度爆炸所带来的梯 度过⼤,⼤幅度更新⽹络参数,在极端情况下,结果会溢出(NaN值). 

五.总结 

今天我们学习了传统rnn模型的具体代码实现以及优缺点,这里就是不给大家总结了,内容也比较少,期待大家的点赞关注加收藏。 

 

 

 

http://www.dtcms.com/a/292254.html

相关文章:

  • NLP自然语言处理的一些疑点整理
  • 【CVPR 2025】即插即用DarkIR, 频域-空间协同的高效暗光恢复!
  • 深度学习 ---参数初始化以及损失函数
  • 从0到1学Pandas(一):Pandas 基础入门
  • Mixed Content错误:“mixed block“ 问题
  • React + ts 中应用 Web Work 中集成 WebSocket
  • linux初识网络及UDP简单程序
  • 2025年母单脱焦虑计划:社交恐惧者的塔罗赋能训练营
  • leetcode 1695. 删除子数组的最大得分 中等
  • 二分查找-852.山峰数组的峰顶索引-力扣(LeetCode)
  • 力扣 hot100 Day52
  • LeetCode 633.平方数之和
  • XML高效处理类 - 专为Office文档XML处理优化
  • Mysql-场景篇-2-线上高频访问的Mysql表,如何在线修改表结构影响最小?-1--Mysql8.0版本后的INSTANT DDL方案(推荐)
  • 【MySQL】MySQL基本概念
  • NISP-PTE基础实操——命令执行
  • MySQL高可用主从复制原理及常见问题
  • mysql_innodb_cluster_metadata源数据库
  • n1 armbian docker compose 部署aipan mysql
  • 板凳-------Mysql cookbook学习 (十二--------5)
  • vue3实现高性能pdf预览器功能可行性方案及实践(pdfjs-dist5.x插件使用及自定义修改)
  • Redis高级篇之最佳实践
  • VUE 中父级组件使用JSON.stringify 序列化子组件传递循环引用错误
  • TDengine时序数据库 详解
  • 扣子Coze智能体实战:自动化拆解抖音对标账号,输出完整分析报告(喂饭级教程)
  • STM32-SPI全双工同步通信
  • 什么是分布式事务,分布式事务的解决方案有哪些?
  • PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环
  • 自编码器表征学习:重构误差与隐空间拓扑结构的深度解析
  • vue2.0 + elementui + i18n:实现多语言功能