当前位置: 首页 > news >正文

NLP:LSTM和GRU分享

本文目录:

  • 一、 前置知识:RNN的痛点
  • 二、LSTM(长短期记忆网络)
    • (一)遗忘门(Forget Gate)
    • (二)输入门(Input Gate)
    • (三)输出门(Output Gate)
    • (四)细胞状态(Cell State)
    • (五)使用Pytorch构建LSTM模型
  • 三、 GRU(门控循环单元)
    • (一)更新门(Update Gate)
    • (二)重置门(Reset Gate)
    • (三)使用Pytorch构建GRU模型
  • 文末附赠:
    • (一)传统RNN、 LSTM 和 GRU 三者核心结构对比
    • (二)传统RNN、 LSTM 和 GRU 三者性能表现对比
    • (三)传统RNN、 LSTM 和 GRU 三者应用场景对比

前言:前面文章分享了传统RNN,此次分享传统RNN变体:LSTM和GRU。

一、 前置知识:RNN的痛点

传统RNN像金鱼记忆:

只能记住最近几步的信息(梯度消失/爆炸)

遇到长序列就懵圈(“开头说了啥来着?”)

二、LSTM(长短期记忆网络)

在这里插入图片描述
核心设计:记忆管控大师

想象LSTM是个图书馆管理员,它有三把钥匙

(一)遗忘门(Forget Gate)

决定哪些旧记忆该丢弃:
“上个月的天气预报数据?可以忘了。”

(二)输入门(Input Gate)

决定哪些新信息值得记录:
“今天突然下冰雹?这个得重点记!”

(三)输出门(Output Gate)

决定当前输出什么信息:
“根据天气记录,建议你带伞。”

(四)细胞状态(Cell State)

像传送带,专门运输长期记忆。

公式精华

f_t = σ(W_f · [h_{t-1}, x_t] + b_f)  # 遗忘门  
i_t = σ(W_i · [h_{t-1}, x_t] + b_i)   # 输入门  
C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C) # 候选记忆  
C_t = f_t * C_{t-1} + i_t * C̃_t        # 更新细胞状态  
o_t = σ(W_o · [h_{t-1}, x_t] + b_o)    # 输出门  
h_t = o_t * tanh(C_t)                  # 最终输出  

(五)使用Pytorch构建LSTM模型

位置: 在torch.nn工具包之中, 通过torch.nn.LSTM可调用。

nn.LSTM类初始化主要参数解释: input_size: 输入张量x中特征维度的大小。

hidden_size: 隐层张量h中特征维度的大小., num_layers: 隐含层的数量。

bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用。

nn.LSTM类实例化对象主要参数解释: input: 输入张量x、h0: 初始化的隐层张量h、c0: 初始化的细胞状态张量c。

nn.LSTM使用示例:

# 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)>>> import torch.nn as nn
>>> import torch
>>> rnn = nn.LSTM(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> c0 = torch.randn(2, 3, 6)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],[ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],[-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364,  0.0645, -0.3996, -0.0500, -0.0152],[ 0.3852,  0.0704,  0.2103, -0.2524,  0.0243,  0.0477],[ 0.2571,  0.0608,  0.2322,  0.1815, -0.0513, -0.0291]],[[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],[ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],[-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500,  0.1009, -0.5806, -0.0668, -0.1161],[ 0.7438,  0.0957,  0.5509, -0.7725,  0.0824,  0.0626],[ 0.3131,  0.0920,  0.8359,  0.9187, -0.4826, -0.0717]],[[ 0.1240, -0.0526,  0.3035,  0.1099,  0.5915,  0.0828],[ 0.0203,  0.8367,  0.9832, -0.4454,  0.3917, -0.1983],[-0.2976,  0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],grad_fn=<StackBackward>)
特别分享:什么是Bi-LSTM ?- Bi-LSTM即双向LSTM, 它没有改变LSTM本身任何的内部结构, 只是将LSTM应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出。

三、 GRU(门控循环单元)

在这里插入图片描述

核心设计LSTM的极简版

GRU像效率至上的程序员,合并了LSTM的门:

(一)更新门(Update Gate)

二合一:同时控制遗忘和输入
“旧记忆留多少?新记忆收多少?这门说了算!”

(二)重置门(Reset Gate)

决定多少过去信息用于计算新状态:

“昨天的股票数据对预测今天有用吗?”

独特优势:

参数比LSTM少(训练更快)

公式精华:

z_t = σ(W_z · [h_{t-1}, x_t])         # 更新门  
r_t = σ(W_r · [h_{t-1}, x_t])         # 重置门  
h̃_t = tanh(W · [r_t * h_{t-1}, x_t])  # 候选状态  
h_t = (1-z_t) * h_{t-1} + z_t * h̃_t   # 最终状态  

(三)使用Pytorch构建GRU模型

位置: 在torch.nn工具包之中, 通过torch.nn.GRU可调用。

nn.GRU类初始化主要参数解释: input_size: 输入张量x中特征维度的大小。

hidden_size: 隐层张量h中特征维度的大小, num_layers: 隐含层的数量。

bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用。

nn.GRU类实例化对象主要参数解释: * input: 输入张量x. * h0: 初始化的隐层张量h。

nn.GRU使用示例:

>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.GRU(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> output, hn = rnn(input, h0)
>>> output
tensor([[[-0.2097, -2.2225,  0.6204, -0.1745, -0.1749, -0.0460],[-0.3820,  0.0465, -0.4798,  0.6837, -0.7894,  0.5173],[-0.0184, -0.2758,  1.2482,  0.5514, -0.9165, -0.6667]]],grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.6578, -0.4226, -0.2129, -0.3785,  0.5070,  0.4338],[-0.5072,  0.5948,  0.8083,  0.4618,  0.1629, -0.1591],[ 0.2430, -0.4981,  0.3846, -0.4252,  0.7191,  0.5420]],[[-0.2097, -2.2225,  0.6204, -0.1745, -0.1749, -0.0460],[-0.3820,  0.0465, -0.4798,  0.6837, -0.7894,  0.5173],[-0.0184, -0.2758,  1.2482,  0.5514, -0.9165, -0.6667]]],grad_fn=<StackBackward>)
特别分享:什么是Bi-GRU ?Bi-GRU与Bi-LSTM的逻辑相同, 都是不改变其内部结构, 而是将模型应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出。

文末附赠:

(一)传统RNN、 LSTM 和 GRU 三者核心结构对比

在这里插入图片描述

(二)传统RNN、 LSTM 和 GRU 三者性能表现对比

在这里插入图片描述

(三)传统RNN、 LSTM 和 GRU 三者应用场景对比

在这里插入图片描述

今天的分享到此结束(改用了一种更温馨的文体~希望大家喜欢(╹▽╹)

http://www.dtcms.com/a/282438.html

相关文章:

  • 加速度传感器的用途与应用
  • Opencv---cv::minMaxLoc函数
  • Go与Python在数据管道与分析项目中的抉择:性能与灵活性的较量
  • React 中 props 的最常用用法精选+useContext
  • 单列集合顶层接口Collection
  • QT——事件系统详解
  • YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架(原创创新算法)
  • GaussDB 数据库架构师修炼(五) 存储容量评估
  • 动态规划题解_打家劫舍【LeetCode】
  • MySQL 8.0 OCP 1Z0-908 题目解析(27)
  • 钱包核心标准 BIP32、BIP39、BIP44:从助记词到多链钱包的底层逻辑
  • RocketMQ源码级实现原理-消息过滤与重试
  • 【Deepseek-R1+阿里千问大模型】四步完成本地调用本地部署大模型和线上大模型,实现可视化使用
  • 拥抱主权AI:OpenCSG驱动智能体运营,共筑新加坡智能高地
  • 【技术追踪】基于检测器引导的对抗性扩散攻击器实现定向假阳性合成——提升息肉检测的鲁棒性(MICCAI-2025)
  • 辅助驾驶GNSS高精度模块UM680A外形尺寸及上电与下电
  • 剑指offer64_圆圈中最后剩下的数字
  • 为什么要用erc165识别erc721或erc1155
  • 系统性学习C语言-第十八讲-C语言内存函数
  • IIS-网站报500.19错误代码0x8007000d问题解决
  • LeetCode Hot100【4. 寻找两个正序数组的中位数】
  • 什么是 WebClient?
  • xss-labs的小练
  • 基于faster-r-cnn行人检测和ResNet50+FPN的可见光红外图像多模态算法融合创新
  • VIVADO技巧_BUFGMUX时序优化
  • 比特币技术简史 第二章:密码学基础 - 哈希函数、公钥密码学与数字签名
  • 基于阿里云云服务器-局域网组网软件
  • Mfc初始化顺序
  • 【27】MFC入门到精通——MFC 修改用户界面登录IP IP Address Control
  • 虚幻引擎5 GAS开发俯视角RPG游戏 #06-7:无限游戏效果