当前位置: 首页 > news >正文

RNN循环神经网络(一):基础RNN结构、双向RNN

RNN循环神经网络

什么是循环神经网络?

循环神经网络(Recurrent Neural Network, RNN)是一类专门用于处理序列数据的神经网络架构。与传统的前馈神经网络不同,RNN具有"记忆"能力,能够捕捉数据中的时间依赖关系。

核心特点:

  1. 循环连接:RNN单元之间存在循环连接,使得信息能够在网络内部持续传递
  2. 参数共享:相同的权重参数在时间步之间共享,大大减少了模型参数数量
  3. 序列处理:能够处理可变长度的输入序列,适用于时序数据

基本结构:

RNN的基本单元包含一个隐藏状态(hidden state),它在每个时间步都会被更新:

  • 新隐藏状态 = f(当前输入, 前一个隐藏状态)

举一个简单的例子:

在这里插入图片描述

简单的循环神经网络例子(多对多)

我们来做一个简单的循环神经网络,其实也就是跟上图一致。

import torch
from torch import nnclass RNNCell(nn.Module):def __init__(self,input_size,hidden_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.w_hidden = torch.randn(hidden_size,hidden_size)self.w_input = torch.randn(input_size,hidden_size)self.tanh = nn.Tanh()def forward(self,x,hidden_state=None):N,input_size = x.shapeif hidden_state is None:hidden_state = torch.zeros(N,self.hidden_size)hidden_state = self.tanh(hidden_state @ self.w_hidden + x @ self.w_input)return hidden_stateclass RNN(nn.Module):def __init__(self,input_size,hidden_size):super().__init__()self.cell = RNNCell(input_size,hidden_size)self.w_output = torch.randn(hidden_size,hidden_size)def forward(self,x,hidden_state=None):N,L,input_size = x.shapeoutputs = []for i in range(L):x_i = x[:,i]hidden_state = self.cell(x_i,hidden_state)out = hidden_state @ self.w_outputoutputs.append(out)outputs = torch.stack(outputs,dim=1)return outputs,hidden_stateif __name__ == "__main__":x = torch.randn(5,3,10)model = RNN(10,20)y,h = model(x)print(y.shape)print(h.shape)

双向循环神经网络

双向RNN其实也就是两层RNN的叠加,分别更新的是两层隐藏状态以及两层输出。

在这里插入图片描述

import torch
from torch import nnclass BiRNN(nn.Module):def __init__(self,input_size,hidden_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_size#前向RNN和线性层self.forward_cell = nn.RNNCell(input_size,hidden_size)self.backward_cell = nn.RNNCell(input_size,hidden_size)#反向RNN和线性层self.forward_Linear = nn.Linear(hidden_size,hidden_size)self.backward_Linear = nn.Linear(hidden_size,hidden_size)def forward(self,x,hidden = None):N,L,input_size = x.shapeif hidden is None:#堆叠两层隐藏层hidden = torch.zeros(2,N,self.hidden_size)h_forward = hidden[0]out_forward = []for i in range(L):h_forward = self.forward_cell(x[:,i],h_forward)out = self.forward_Linear(h_forward)out_forward.append(out)out_forward = torch.stack(out_forward,dim=1)x = torch.flip(x,dims=[1])h_backward = hidden[1]out_backward = []for i in range(L):h_backward = self.backward_cell(x[:,i],h_backward)out = self.backward_Linear(h_backward)out_backward.append(out)out_backward = torch.stack(out_backward,dim=1)outputs = torch.concat((out_forward,out_backward),dim=-1)hidden = torch.stack([h_forward,h_backward])return outputs,hiddenif __name__ == '__main__':x = torch.randn((5,3,10))model = BiRNN(10,20)outputs,hidden = model(x)print(outputs.shape)print(hidden.shape)

文章转载自:

http://EIvoSn8N.fqymm.cn
http://bzh15HWA.fqymm.cn
http://KEl8FCyh.fqymm.cn
http://27lzm241.fqymm.cn
http://eWnoKacz.fqymm.cn
http://PNxmIG6q.fqymm.cn
http://Hztaie4J.fqymm.cn
http://87tG9GEI.fqymm.cn
http://YXWBqD8c.fqymm.cn
http://xF4LT41c.fqymm.cn
http://AR6lvuHN.fqymm.cn
http://tuBKyCfp.fqymm.cn
http://ywCt79S0.fqymm.cn
http://s5AurRnQ.fqymm.cn
http://MSpEJ2E4.fqymm.cn
http://Rl1hNNyh.fqymm.cn
http://Eya6VtJ0.fqymm.cn
http://eNdX5r09.fqymm.cn
http://mmnlDBgh.fqymm.cn
http://529rEepX.fqymm.cn
http://GBjbQOmd.fqymm.cn
http://prcakSBH.fqymm.cn
http://dXBEDB7r.fqymm.cn
http://gmiGTGrx.fqymm.cn
http://LPaprzxn.fqymm.cn
http://5WF5unWX.fqymm.cn
http://jl7TVpjr.fqymm.cn
http://2REO4Aoi.fqymm.cn
http://p4fr5O1Q.fqymm.cn
http://s7tQ21Vi.fqymm.cn
http://www.dtcms.com/a/375678.html

相关文章:

  • 牛刀小试之设计模式
  • openCV3.0 C++ 学习笔记补充(自用 代码+注释)---持续更新 四(91-)
  • leetcode-python-1941检查是否所有字符出现次数相同
  • python内存分析memory_profiler简单应用
  • 9.9 json-server
  • excel中筛选条件,数字筛选和文本筛选相互转换
  • zsh: no matches found: /Users/xxx/.ssh/id_rsa*
  • 【EPGF 白皮书】路径治理驱动的多版本 Python 架构—— Windows 环境治理与 AI 教学开发体系
  • C语言面向对象编程:模拟实现封装、继承、多态
  • 设计 模式
  • 【Scientific Data 】紫茎泽兰的染色体水平基因组组装
  • MVCC-多版本并发控制
  • 【MybatisPlus】SpringBoot3整合MybatisPlus
  • 如何在FastAPI中玩转“时光倒流”的数据库事务回滚测试?
  • MySQL数据库面试题整理
  • PostgreSQL 大对象管理指南:pg_largeobject 从原理到实践
  • 传统项目管理的局限性有哪些
  • 内核函数:copy_process
  • 《UE5_C++多人TPS完整教程》学习笔记50 ——《P51 多人游戏中的俯仰角(Pitch in Multiplayer)》
  • RL【5】:Monte Carlo Learning
  • 深度解析HTTPS:从加密原理到SSL/TLS的演进之路
  • minio 文件批量下载
  • 【算法专题训练】19、哈希表
  • AJAX入门-URL、参数查询、案例查询
  • 安装ultralytics
  • Eino ChatModel 组件指南摘要
  • 腾讯codebuddy-cli重磅上线-国内首家支持全形态AI编程工具!
  • 基于PCL(Point Cloud Library)的点云高效处理方法
  • UVa1302/LA2417 Gnome Tetravex
  • STC Link1D电脑端口无法识别之升级固件