当前位置: 首页 > news >正文

RNN模型及NLP应用(5/9)——多层RNN、双向RNN、预训练

声明:

       本文基于哔站博主【Shusenwang】的视频课程【RNN模型及NLP应用】,结合自身的理解所作,旨在帮助大家了解学习NLP自然语言处理基础知识。配合着视频课程学习效果更佳。

材料来源:【Shusenwang】的视频课程【RNN模型及NLP应用】

视频链接:

RNN模型与NLP应用(5/9):多层RNN、双向RNN、预训练_哔哩哔哩_bilibili


一、学习目标

1.了解什么是多层RNN、双向RNN、预训练

2.学会代码实现多层RNN、双向RNN、预训练

3.2.清楚这三种RNN模型的底层逻辑


 二、多层RNN、双向RNN、预训练

1.多层RNN

将RNN堆叠起来构成深度RNN神经网络,以下为示意图:

神经网络每一步都会更新状态h,每一个状态向量h都有两个copies,一个输出,一个用来传给下一节点。这一层输出的h,用于上一层的输入。以此类推,第二层的状态向量一个用来传给下一个时间点,另一个用来作为第三层的输入。

最后一层的ht是最终的输出。


代码实现:

       上图中红色方框标记处:用了三个LSTM层,第一层return_sequence=true,第一层LSTM的输出需要作为第二层的输入;第二层return_sequence=true,第二层LSTM的输出需要作为第三层的输入,而第三层return_sequence=flase,第三层LSTM的输出为最终输出。


2.双向RNN

训练两条RNN,一条从左往右输出,一条从右往左输出。

同一个输入X同时输给两条RNN,两条RNN互相独立,互不影响。

如果有多条RNN,那么将最上面的内一行的y作为输入再传给下一条RNN。

最后只取左右两边的ht和ht',如下图:

代码实现:

要实现双向RNN,则需要导入Bidirectional层,然后在标准的LSTM外套一个Bidirectional层即可


3.预训练

预训练在深度学习中非常常用,比如在训练卷积神经网络中,如果网络足够大,但数据集不够大,

这时候你就可以在Imagenet(大规模的注释图像数据库)上做预训练,这样可以让神经网络有更好的初始化,也可以避免overfitting(过拟合)

训练RNN的时候也是这个道理

       如上图所示,这个RNN模型的embedding层有320000个参数,那么在数据集很小的情况下,该模型很可能会产生过拟合。

那么RNN预训练具体是这么做的:

    【首先】要找到一个足够大的数据集,可以是情感芬妮下的数据集,也可以是其他类型的数据,但是任务最好是接近原来情感分析的任务,最好是学出来的词向量带有正面或负面的情感词。两个任务越相似,预训练出来的transform(表现)就会越好。这个神经网络的结构是什么样的都可以,只要他有embedding层就行。

     【其次】就是要在大数据集上训练这个神经网络,训练好后把上面的层丢掉,只保留eembedding层和训练好的模型参数

【最后】再搭建我们自己的RNN模型

       新的RNN层和全连接层的参数都是随机初始化的,而下面的embedding层是预训练出来的,要把embedding层固定住,不要训练embedding层


三、总结

①能用SimpleRNN的情况下,肯定可以用LSTM,LSTM的效果要比Simple RNN好,因此我们应该都用LSTM

②提升RNN效果的方式之一就是使用双向RNN,双向RNN比单向的训练效果好。

③多层RNN的容量比单层大,如果训练数据比较多,那么多层的RNN训练效果比单层的好

④RNN的embedding层中参数往往都很多,那么在数据集较小的情况下,训练可能会出现over fitting(过拟合),因为我们就需要在大数据集上先进行预训练。

http://www.dtcms.com/a/108175.html

相关文章:

  • js防抖函数防抖无效的解决方法
  • 14.网络套接字TCP
  • 5.好事多磨 -- TCP网络连接Ⅱ
  • LabVIEW多线程
  • API vs 网页抓取:获取数据的最佳方式
  • PyTorch中.pth文件的解析及应用
  • Linux的TCP连接数到达2万,其中tcp_tw、tcp_alloc、tcp_inuse都很高,可能出现什么问题
  • Python `async` 和 `asyncio` 区别; `asyncio.Lock` 和 `threading.Lock`区别
  • pyqt SQL Server 数据库查询-优化2
  • 使用ChromaDB构建RAG知识库
  • SSH远程连接服务器(cursor)
  • ssh私钥文件登录问题:Load key invalid format
  • spring-ai-alibaba第四章阿里dashscope集成百度翻译tool
  • 端到端机器学习流水线(MLflow跟踪实验)
  • Vue3+Vite+TypeScript+Element Plus开发-04.静态菜单设计
  • Java单例模式详解
  • 深入理解 CSS 选择器:从基础到高级的样式控制
  • iPhone 16怎么录制屏幕内容?屏幕录制技巧、软件分享
  • eBest AI智能报表:用自然语言对话解锁企业数据生产力
  • PostgreSQL HAVING 子句详解
  • 最小二乘求解器lstsq,处理带权重和L2正则的线性回归
  • Vue3 + Element Plus + AntV X6 实现拖拽树组件
  • 【人工智能之大模型】如何缓解大语言模型LLMs重复读的问题?
  • 函数ioctl(Input/Output Control)
  • mac如何将jar包上传到maven中央仓库中
  • LeetCode-695. 岛屿的最大面积
  • Linux系统之systemctl管理服务及编译安装配置文件安装实现systemctl管理服务
  • Redis-10.在Java中操作Redis-Spring Data Redis使用方式-操作步骤说明
  • 基于随机森林算法的信用风险评估项目
  • 汇编学习结语