当前位置: 首页 > news >正文

TensorFlow2 Python深度学习 - 循环神经网络(LSTM)示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 循环神经网络(LSTM)示例

LSTM(长短期记忆网络,Long Short-Term Memory)是一种特殊类型的循环神经网络(RNN),用于处理和预测序列数据。它能够有效地解决标准RNN在长期依赖问题中的缺点,如梯度消失和梯度爆炸问题。LSTM的关键在于其特殊的结构,其中包括了三个“门”机制:输入门、遗忘门和输出门,这些门控制信息流的进入、遗忘和输出,允许模型更好地捕捉和保持长期的依赖关系。

LSTM的基本结构

LSTM单元的结构包括以下几部分:

  1. 输入门(Input Gate):决定哪些新信息被写入到单元状态。

  2. 遗忘门(Forget Gate):决定哪些信息会从单元状态中丢弃或保留。

  3. 输出门(Output Gate):决定哪些信息将用于输出。

tf.keras.layers.LSTM(units,activation='tanh',recurrent_activation='sigmoid',use_bias=True,kernel_initializer='glorot_uniform',recurrent_initializer='orthogonal',bias_initializer='zeros',unit_forget_bias=True,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,go_backwards=False,stateful=False,time_major=False,unroll=False,**kwargs
)

核心参数:

  1. units - 最重要的参数

  • 作用:定义LSTM层中记忆单元的数量

  • 通俗理解:LSTM的"脑容量"或"记忆力大小"

  • 影响:值越大,模型表达能力越强,但计算复杂度越高

  • 建议范围:32-1024,根据任务复杂度选择

  1. return_sequences - 输出控制

  • 作用:控制是否返回所有时间步的输出

  • 默认值False(只返回最后一个时间步的输出)

  • 使用场景

    • False:用于分类、情感分析等只需要最终结果的场景

    • True:用于序列标注、机器翻译等需要每个时间步输出的场景

  1. dropoutrecurrent_dropout - 正则化参数

  • dropout:输入单元的丢弃率,防止过拟合

  • recurrent_dropout:循环连接的丢弃率,防止循环过拟合

  • 建议值:0.2-0.5,根据数据量和模型复杂度调整

  1. activationrecurrent_activation - 激活函数

  • activation:主要计算的激活函数,默认'tanh'

  • recurrent_activation:门控单元的激活函数,默认'sigmoid'

  1. return_state - 状态返回

  • 作用:是否返回LSTM的隐藏状态和细胞状态

  • 使用场景:编码器-解码器结构、状态传递等高级应用

  1. stateful - 状态保持

  • 作用:批次间是否保持LSTM状态

  • 使用场景:处理超长序列需要分批时保持状态连续性

  1. unroll - 展开计算

  • 作用:是否将RNN展开为前馈网络

  • 优点:加速训练(适合短序列)

  • 缺点:内存消耗大(不适合长序列)

示例:

import tensorflow as tf
from keras import Input, layers
from keras.src.utils import pad_sequences
​
# 1. 加载 IMDB 数据集
max_features = 10000  # 使用词汇表中前 10000 个常见单词
maxlen = 100  # 每条评论的最大长度
​
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
print(x_train.shape, x_test.shape)
print(x_train[0])
print(y_train)
​
# 2. 数据预处理:对每条评论进行填充,使其长度统一
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)
​
# 3. 构建 RNN 模型
model = tf.keras.models.Sequential([Input(shape=(maxlen,)),layers.Embedding(input_dim=max_features, output_dim=128),  # 嵌入层,将单词索引映射为向量 output_dim  嵌入向量的维度(即每个输入词的嵌入表示的长度)layers.LSTM(units=64, dropout=0.2, recurrent_dropout=0.2),# LSTM 层:包含 64 个神经元,激活函数默认使用 tanh  dropout表示在每个时间步上丢弃20% recurrent_dropout 递归状态(即隐藏状态)的dropout比率为20%layers.Dense(1, activation='sigmoid')  # 输出层:用于二分类(正面或负面),激活函数为 sigmoid
])
​
# 4. 模型编译
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
​
# 5. 模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6. 模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

运行结果:

http://www.dtcms.com/a/499274.html

相关文章:

  • C++第二十三课:猜数字游戏等练习
  • 河南省建设厅网站中州杯企业网站推广怎么做
  • 【数论】最大公因数 (gcd) 与最小公倍数 (lcm)
  • rocky linux MariaDB安装过程
  • git的 Rebase
  • 第8篇 QT联合halcon12在vs2019搭建环境开发图像处理
  • 【小白笔记】最大交换 (Maximum Swap)问题
  • CentOS安装Node.js
  • 深入解析MCP:从基础配置到高级应用指南
  • 佛山网站建设服务wordpress 不能更换主题
  • Process Monitor 学习笔记(5.13):从 0 到 1 的排障剧本清单(可复用模板)
  • Fluent 重叠网格+UDF NACA0012翼型摆动气动仿真
  • 深圳网站建设 设计卓越迈wordpress一键采集文章
  • 理想汽车Java后台开发面试题及参考答案(下)
  • python|if判断语法对比
  • 全链路智能运维中的实时流处理架构与状态管理技术
  • 排序算法:详解快速排序
  • 安阳哪里做360网站科技感十足的网站
  • UV 紫外相机在半导体制造领域的应用
  • 突破亚微米光电子器件制造瓶颈!配体交换辅助打印技术实现全打印红外探测器
  • 可见光工业相机半导体制造领域中的应用
  • require和 import是两种不同的模块引入方式的区别
  • 半导体制造工艺基本认识 五 薄膜沉积
  • 矩阵及其应用
  • **发散创新:探索零信任网络下的安全编程实践**随着信息技术的飞速发展,网络安全问题日益凸显。传统的网络安全防护方式已难以
  • 网络营销方案毕业设计安卓手机性能优化软件
  • 建设企业网站价格建设银行北京市财满街分行网站
  • (Kotlin高级特性一)kotlin的扩展函数和属性在字节码层面是如何实现的
  • Spring Boot 3零基础教程,WEB 开发 静态资源默认配置 笔记27
  • 【论文精度-2】求解车辆路径问题的神经组合优化算法:综合展望(Yubin Xiao,2025)