当前位置: 首页 > news >正文

零基础-动手学深度学习-9.1. 门控循环单元(GRU)及代码实现

  前一章中我们介绍了循环神经网络的基础知识, 这种网络可以更好地处理序列数据。 我们在文本数据上实现了基于循环神经网络的语言模型, 但是对于当今各种各样的序列学习问题,这些技术可能并不够用。例如,循环神经网络在实践中一个常见问题是数值不稳定性。 尽管我们已经应用了梯度裁剪等技巧来缓解这个问题, 但是仍需要通过设计更复杂的序列模型来进一步处理它。 具体来说,我们将引入两个广泛使用的网络, 即门控循环单元(gated recurrent units,GRU)和 长短期记忆网络(long short-term memory,LSTM)。 然后,我们将基于一个单向隐藏层来扩展循环神经网络架构。 我们将描述具有多个隐藏层的深层架构, 并讨论基于前向和后向循环计算的双向设计。 现代循环网络经常采用这种扩展。 在解释这些循环神经网络的变体时, 我们将继续考虑 8节中的语言建模问题。事实上,语言建模只揭示了序列学习能力的冰山一角。 在各种序列学习问题中,如自动语音识别、文本到语音转换和机器翻译, 输入和输出都是任意长度的序列。 为了阐述如何拟合这种类型的数据, 我们将以机器翻译为例介绍基于循环神经网络的 “编码器-解码器”架构和束搜索,并用它们来生成序列

9.1. 门控循环单元(GRU)

在 8.7节中, 我们讨论了如何在循环神经网络中计算梯度, 以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。 下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。

  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。

  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示

在学术界已经提出了许多方法来解决这类问题。 其中最早的方法是“长短期记忆”(long-short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997), 我们将在 9.2节中讨论。 门控循环单元(gated recurrent unit,GRU) (Cho et al., 2014) 是一个稍微简化的变体,通常能够提供同等的效果, 并且计算 (Chung et al., 2014)的速度明显更快。 由于门控循环单元更简单,我们从它开始解读。

9.1.1. 门控隐状态

对于重置门和更新门的理解为各司其职。重置门单方面控制自某个节点开始,之前的记忆(隐状态)不在乎了,直接清空影响,同时也需要更新门帮助它实现记忆的更新。更新们更多是用于处理梯度消失问题,可以选择一定程度地保留记忆,防止梯度消失。

9.1.2. 从零开始实现

为了更好地理解门控循环单元模型,我们从零开始实现它。 首先,我们读取 8.5节中使用的时间机器数据集:

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

9.1.2.1. 初始化模型参数

下一步是初始化模型参数。 我们从标准差为的高斯分布中提取权重, 并将偏置项设为,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01
#前面两个def和rnn是一样的def three():#这个函数用来辅助定义两个w和一个breturn (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))
#是的就是方便在这个下面进行一个参数定义W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 对所有parmes附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

9.1.2.2. 定义模型

现在我们将定义隐状态的初始化函数init_gru_state。 与 8.5节中定义的init_rnn_state函数一样, 此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )#这里还是返回一个tuple 忘了为啥来着 回去看来看因为lstm需要另一个参数嘻嘻

现在我们准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

@是按矩阵乘法,*是按元素乘法的意思

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:#对每个序列的时间长度Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

9.1.2.3. 训练与预测

训练和预测的工作方式与 8.5节完全相同。 训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)#gru放进去
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)输出:
perplexity 1.1, 19911.5 tokens/sec on cuda:0
time traveller firenis i heidfile sook at i jomer and sugard are
travelleryou can show black is white by argument said filby

9.1.3. 简洁实现

高级API包含了前文介绍的所有配置细节, 所以我们可以直接实例化门控循环单元模型。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)#矩阵乘法做到一起所以速度快
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)输出:
perplexity 1.0, 109423.8 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi

9.1.4. 小结

  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。

  • 重置门有助于捕获序列中的短期依赖关系。

  • 更新门有助于捕获序列中的长期依赖关系。

  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

http://www.dtcms.com/a/319890.html

相关文章:

  • 解决 npm i node-sass@4.12.0 安装失败异常 npm i node-sass异常解决
  • 如何使用 pnpm创建Vue 3 项目
  • 玳瑁的嵌入式日记D14-0807(C语言)
  • 蓝凌EKP产品:列表查询性能优化全角度
  • C++引用专题(上):详解C++传值返回和传引用返回
  • JavaScript核心概念解析:从基础语法到对象应用
  • 部署 AddressSanitizer(ASan)定位内存泄漏、内存越界
  • Java+Vue合力开发固定资产条码管理系统,移动端+后台管理,集成资产录入、条码打印、实时盘点等功能,助力高效管理,附全量源码
  • 【保姆级喂饭教程】python基于mysql-connector-python的数据库操作通用封装类(连接池版)
  • SPI TFT全彩屏幕驱动开发及调试
  • Sentinel原理之责任链详解
  • imx6ull-驱动开发篇12——GPIO子系统驱动LED
  • C++高频知识点(十五)
  • Qwen-Image开源模型实战
  • 【Floyd】Shortest Routes II
  • 显卡服务器的作用主要是什么?-哈尔滨云前沿
  • 使用内网穿透工具1分钟上线本地网站至公网可访问,局域网电脑变为服务器
  • Mysql数据仓库备份脚本
  • 2.7 (拓展)非父子通信(事件总线和provide-inject)详解
  • 2025 年华数杯全国大学生数学建模竞赛B题 网络切片无线资源管理方案设计--完整成品、思路、代码、模型结果分享,仅供学习~
  • java 生成pdf导出
  • 【tip】font-family的设置可能导致的文字奇怪展示
  • 《P3275 [SCOI2011] 糖果》
  • 运营商面向政企客户推出的DICT项目
  • 【ee类保研面试】数学类---概率论
  • 5G专网提高产业生产力
  • 别墅泳池设计综述:从理念创新到技术实现的系统性研究
  • 基于 PyTorch 从零实现 Transformer 模型:从核心组件到训练推理全流程
  • Java 大视界 -- Java 大数据在智能安防门禁系统中的人员行为分析与异常事件预警(385)
  • nvm安装,nvm管理node版本