当前位置: 首页 > news >正文

TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

GRU(门控循环单元,Gated Recurrent Unit)是一种用于处理序列数据的递归神经网络(RNN)模型。它是为了克服传统RNN在长时间序列中训练时遇到的梯度消失问题(即记忆力衰减)而提出的。GRU相较于LSTM(长短时记忆网络),结构更简单,计算效率更高,但能够实现类似的性能。

GRU的工作原理

GRU通过引入两个重要的门控机制来控制信息的流动:

  1. 更新门(Update Gate):决定了当前时刻的隐藏状态有多少部分应该被更新,多少部分保留旧的隐藏状态信息。这个门类似于LSTM中的遗忘门和输入门的结合。它通过一个Sigmoid激活函数控制当前输入和上一时刻的隐藏状态对当前时刻的影响。

  2. 重置门(Reset Gate):控制遗忘多少旧的隐藏状态信息。它决定了当前时刻输入和过去隐藏状态的结合程度,从而决定了网络保留多少历史信息。通过一个Sigmoid激活函数来计算。

GRU的关键是利用这些门控机制在时间步之间传递信息,从而可以更好地捕获序列数据中的长期依赖性。

tf.keras.layers.GRU(units,activation='tanh',recurrent_activation='sigmoid',use_bias=True,kernel_initializer='glorot_uniform',recurrent_initializer='orthogonal',bias_initializer='zeros',dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,go_backwards=False,stateful=False,time_major=False,unroll=False,reset_after=True,**kwargs
)

核心参数:

  1. units - 核心参数

  • 作用:定义GRU层中隐藏单元的数量

  • 影响:决定模型的容量和复杂度,值越大表示记忆能力越强

  • 选择:通常根据任务复杂度在32-512之间选择

  1. activation - 激活函数

  • 作用:定义输出计算的激活函数

  • 默认值'tanh'

  • 功能:控制候选隐藏状态的计算

  1. recurrent_activation - 循环激活函数

  • 作用:定义门控机制的激活函数

  • 默认值'sigmoid'

  • 功能:控制更新门和重置门的计算

  1. return_sequences - 输出控制

  • 作用:控制是否返回所有时间步的输出

  • 默认值False(只返回最后一个时间步)

  • 应用

    • False:用于序列分类任务

    • True:用于序列标注或多层GRU堆叠

  1. return_state - 状态返回

  • 作用:是否返回最后的隐藏状态

  • 使用场景:编码器-解码器架构或需要状态传递的场景

  1. dropout - 输入丢弃率

  • 作用:输入单元的随机丢弃比例,防止过拟合

  • 范围:0.0-1.0,推荐0.2-0.5

  1. recurrent_dropout - 循环丢弃率

  • 作用:循环连接的随机丢弃比例

  • 功能:专门防止循环连接上的过拟合

  1. reset_after - 重置门位置

  • 作用:控制重置门的应用顺序

  • 默认值True(与CuDNN兼容的模式)

  • 影响:影响计算图和性能,通常保持默认

示例:

import tensorflow as tf
from keras import Input, layers
from keras.src.utils import pad_sequences
​
# 1. 加载 IMDB 数据集
max_features = 10000  # 使用词汇表中前 10000 个常见单词
maxlen = 100  # 每条评论的最大长度
​
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
print(x_train.shape, x_test.shape)
print(x_train[0])
print(y_train)
​
# 2. 数据预处理:对每条评论进行填充,使其长度统一
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)
​
# 3. 构建 RNN 模型
model = tf.keras.models.Sequential([Input(shape=(maxlen,)),layers.Embedding(input_dim=max_features, output_dim=128),  # 嵌入层,将单词索引映射为向量 output_dim  嵌入向量的维度(即每个输入词的嵌入表示的长度)layers.GRU(units=64, dropout=0.2, recurrent_dropout=0.2),# GRU 层:包含 64 个神经元,激活函数默认使用 tanh  dropout表示在每个时间步上丢弃20% recurrent_dropout 递归状态(即隐藏状态)的dropout比率为20%layers.Dense(1, activation='sigmoid')  # 输出层:用于二分类(正面或负面),激活函数为 sigmoid
])
​
# 4. 模型编译
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
​
# 5. 模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6. 模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

运行结果:

http://www.dtcms.com/a/502887.html

相关文章:

  • TVM | Relay
  • 使用 Conda 安装 QGIS 也是很好的安装方式
  • 网站套餐到期什么意思抖音seo优化系统招商
  • 怎么看网站pr值衡水市住房和城乡建设局网站
  • 散点拟合圆:Matlab两种方法实现散点拟合圆
  • Kubernetes流量管理:从Ingress到GatewayAPI演进
  • 专做品牌网站西安做网站电话
  • “函数恒大于0”说明函数是可取各不同数值的变数(变量)——“函数是一种对应法则等”是非常明显的错误
  • Linux系统--信号(4--信号捕捉、信号递达)--重点--重点!!!
  • Blender后期合成特效资产预设插件 MP_Comp V2.0.2
  • 达梦8数据库常见故障分析与解决方案
  • 迁移服务器
  • 解决docker构建centos7时yum命令报错、镜像源失效问题
  • 密钥轮换:HashiCorp Vault自动续期,密钥生命周期?
  • 即时通讯系统核心模块实现
  • 【HarmonyOS】组件嵌套优化
  • 福州企业做网站催眠物语wordpress
  • 图文并茂:全面了解UART相关知识(TTL+RS232+RS484)
  • VMware Euler系统Ctrl+C/V共享剪贴板完全指南:从配置到彻底清理
  • IOT项目——STM32
  • 【物联网架构】
  • 【编程】IDEA自定义系统注解格式|自定义自定义注解格式
  • 定位网站关键词dw网页制作模板源代码
  • 【Linux网络】封装Socket
  • Solidity智能合约开发入门攻略
  • AI决策系统:从数据到行动的智能跃迁——底层逻辑与实践全景解析
  • 好看的单页面网站石岩网站设计
  • 未来的 AI 操作系统(二)——世界即界面:自然语言成为新的人机交互协议
  • 经典排序算法的实现与解析
  • 流量转化与生态重构:“开源AI智能名片链动2+1模式S2B2C商城小程序”对直播电商的范式革新