当前位置: 首页 > news >正文

循环神经网络详解

序列模型

序列数据的本质

在现实世界中,时间维度构成了我们理解事件的基础框架。从语言交流到金融市场,从生物序列到传感器网络,时间序列数据无处不在:

现实世界
时间序列数据
自然语言
语音信号
股票价格
DNA序列

序列模型正是为处理这类具有时间或顺序依赖关系的数据而设计的计算框架。其核心挑战在于捕捉元素之间的长期依赖关系,同时处理变长输入和输出序列。

类型

  • 语音识别,输入一段语音输出对应文字
    在这里插入图片描述
  • 情感分类,输入一段表示用户情感的文字,输出情感类别或者评分
    在这里插入图片描述
  • 机器翻译,两种语言的互相翻译
    在这里插入图片描述

为什么CNN在序列模型中表现不佳?

尽管卷积神经网络(CNN)在图像领域大放异彩,但在处理序列数据时面临根本性局限:

  1. 固定感受野

    有限范围
    输入序列
    卷积核
    局部特征
    输出
    • 只能捕捉局部依赖
    • 难以建模长距离依赖
    • 序列数据的输入输出长度不固定
  2. 计算效率问题

    • 为捕捉长距离依赖需要极深网络
    • 参数量随感受野线性增长
  3. 时间对称性

    • 标准卷积无方向性
    • 序列处理需考虑时间方向(过去→未来)

循环神经网络

基础循环网络介绍

基础循环网络示例图
基础循环网络示例图

数学表示
s0=0s_0=0s0=0
st=g1(Uxt+Wst−1+ba)s_t = g1(Ux_t + Ws_{t-1} + b_a)st=g1(Uxt+Wst1+ba)
ot=g2(Vst+by)o_t = g2(Vs_t + b_y)ot=g2(Vst+by)

其中:

  • sts_tst:每一个隐层的输出
  • xtx_txt:每一个时刻的输入
  • oto_tot:每一个时刻的输出
  • WWW:权重矩阵
  • g1,g2g1,g2g1,g2:激活函数

类型

 循环神经网络类型示例图
循环神经网络类型示例图
  1. 一对一

    • 固定输入 → 固定输出
    • 示例:图像分类
  2. 一对多

    • 固定输入 → 序列输出
    • 示例:图像文字描述
  3. 多对一

    • 序列输入 → 输出
    • 示例:情感分析
  4. 多对多

    • 序列输入 → 序列输出
    • 示例:机器翻译
    • 变体:编码器-解码器架构
  5. 同步多对多

    • 同步序列输入 → 同步输出
    • 示例:文本生成,视频每一帧分类

序列生成案例解析:语言建模

以生成句子"我 昨天 上学 迟到 了 e"为例(e表示结束符):

序列生成过程

序列生成案例示意图
序列生成案例示意图

词的表示

为了能够让整个网络能够理解我们的输入(英文/中文等),需要将词用向量表示:

  • ​​建立一个包含所有序列词的词典​​包含(开始和标志的两个特殊词,以及没有出现过的词等) ,每个词在词典中有唯一编号。​

  • ​ 任意一个词可用 N 维 one-hot 向量表示​​,向量维度 N = 词典中词的个数。

    词的表示示意图
    词的表示示意图

输出表示:Softmax概率分布

RNN在每个时间步输出词汇表上的概率分布:

P(wt∣w<t)=Softmax(Woht+bo)P(w_t | w_{<t}) = \text{Softmax}(W_oh_t + b_o)P(wtw<t)=Softmax(Woht+bo)

Softmax函数
σ(z)j=ezj∑k=1Kezkfor j=1,...,K\sigma(z)_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}} \quad \text{for } j=1,...,Kσ(z)j=k=1Kezkezjfor j=1,...,K

示例t=2时刻输出概率:

  • “上学”:0.45
  • “跑步”:0.25
  • “吃饭”:0.15

RNN的数学引擎

矩阵运算表示

假设以上面的例子:对于网络当中某一时刻的公式中

  1. st=relu(Uxt+Wst−1)s_t= relu(Ux_t+Ws_{t-1})st=relu(Uxt+Wst1)
  2. ot=softmax(Vst)o_t = softmax(Vs_t)ot=softmax(Vst)
    在这里插入图片描述
  • 1、形状表示:​​[n,m]×[m,1]+[n,n]×[n,1]=[n,1]​​[n,m]×[m,1]+[n,n]×[n,1]=[n,1]​​[n,m]×[m,1]+[n,n]×[n,1]=[n,1]

    • 则矩阵 ​​U​​ 的维度是 n×m,矩阵 ​​W​​ 的维度是 n×n
    • m:词的个数,n:输出 s 的维度
  • 2、形状表示:​​[m,n]×[n,1]=[m,1][m,n]×[n,1]=[m,1][m,n]×[n,1]=[m,1]

    • 矩阵 ​​V​​ 的维度是 m×n

交叉熵损失函数

对于序列预测任务,使用时间步交叉熵的累积:

L=−∑t=1T∑k=1Kyt,klog⁡y^t,k\mathcal{L} = -\sum_{t=1}^T \sum_{k=1}^K y_{t,k} \log \hat{y}_{t,k}L=t=1Tk=1Kyt,klogy^t,k

其中:

  • TTT:序列长度
  • KKK:词汇表大小
  • yty_tyt:真实标签的独热编码
  • y^t\hat{y}_ty^t:预测概率分布

时序反向传播算法(BPTT)

BPTT算法原理

对于RNN来说有一个时间概念,需要把梯度沿时间通道传播的 BP 算法,所以称为 Back Propagation Through Time-NPTT

我们的​​目标是​​计算误差关于参数 U、V、W 以及 bx、by 的梯度,使用梯度下降法学习参数。(因三组参数共享,需将一个训练实例在​​每时刻的梯度相加​​)

  • 1、要求:​​计算每个时间步的梯度: t=0, t=1, t=2, t=3, t=4,将各时刻梯度​​相加的结果​​作为参数 W 的更新梯度值。
  • 2、求不同参数的导数步骤:​​
    • 最后一个 cell:​​
      • 计算最后一个时刻的交叉熵损失对 sts_tst的梯度,记忆交叉熵损失对于sts_tst,VVV,byb_yby的导数
    • 最后一个 cell 前面的 cell:​​
      • ​​第一步:​​计算当前层损失对当前隐层状态sts_tst的梯度 +++ 上一层相对于sts_tst的损失

      • ​​第二步:​​计算 tanh 激活函数的导数:

      • ​第三步:​​计算 Uxt+Wst−1+baUx_t+Ws_{t-1}+b_aUxt+Wst1+ba的对于不同参数的导数

梯度计算

考虑损失函数对参数WWW的梯度:

∂L∂W=∑t=1T∂L∂st∂st∂xt∂xt∂W\frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \frac{\partial \mathcal{L}}{\partial s_t} \frac{\partial s_t}{\partial x_t} \frac{\partial x_t}{\partial W}WL=t=1TstLxtstWxt
考虑损失函数对参数UUU的梯度:

∂L∂U=∑t=1T∂L∂st∂st∂xt∂xt∂U\frac{\partial \mathcal{L}}{\partial U} = \sum_{t=1}^T \frac{\partial \mathcal{L}}{\partial s_t} \frac{\partial s_t}{\partial x_t} \frac{\partial x_t}{\partial U}UL=t=1TstLxtstUxt

考虑损失函数对参数bbb的梯度:

∂L∂b=∑t=1T∂L∂st∂st∂xt∂xt∂b\frac{\partial \mathcal{L}}{\partial b} = \sum_{t=1}^T \frac{\partial \mathcal{L}}{\partial s_t} \frac{\partial s_t}{\partial x_t} \frac{\partial x_t}{\partial b}bL=t=1TstLxtstbxt

BPTT计算示例

考虑序列长度T=3的简化RNN:

  1. 前向传播

    • s1=tanh⁡(Wx1+bh)s_1 = \tanh(Wx_1 + b_h)s1=tanh(Wx1+bh)
    • s2=tanh⁡(Wx2+Ws1+bh)s_2 = \tanh(Wx_2 + Ws_1 + b_h)s2=tanh(Wx2+Ws1+bh)
    • s3=tanh⁡(Wx3+Ws2+bh)s_3 = \tanh(Wx_3 + Ws_2 + b_h)s3=tanh(Wx3+Ws2+bh)
  2. 反向传播(计算∂L∂W\frac{\partial \mathcal{L}}{\partial W}WL):
    ∂L∂W=δ3∂s3∂W+δ2∂s2∂W+δ1∂s1∂W\frac{\partial \mathcal{L}}{\partial W} = \delta_3 \frac{\partial s_3}{\partial W} + \delta_2 \frac{\partial s_2}{\partial W} + \delta_1 \frac{\partial s_1}{\partial W}WL=δ3Ws3+δ2Ws2+δ1Ws1

    其中:

    • δ3=∂L∂o3∂o3∂s3\delta_3 = \frac{\partial \mathcal{L}}{\partial o_3} \frac{\partial o_3}{\partial s_3}δ3=o3Ls3o3
    • δ2=δ3∂s3∂s2∂s2∂s2+直接梯度\delta_2 = \delta_3 \frac{\partial s_3}{\partial s_2} \frac{\partial s_2}{\partial s_2} + \text{直接梯度}δ2=δ3s2s3s2s2+直接梯度
    • δ1=δ2∂s2∂s1∂s1∂s1\delta_1 = \delta_2 \frac{\partial s_2}{\partial s_1} \frac{\partial s_1}{\partial s_1}δ1=δ2s1s2s1s1
  3. 梯度累积
    ∂L∂W=δ3s2+δ2s1+δ1⋅0\frac{\partial \mathcal{L}}{\partial W} = \delta_3 s_2 + \delta_2 s_1 + \delta_1 \cdot 0WL=δ3s2+δ2s1+δ10

RNN的核心挑战:梯度动力学

梯度消失问题

现象:长期依赖梯度指数级衰减

指数衰减
时间t
梯度
t-n

数学本质
∥∂st∂sk∥≈∥∏i=kt−1∂si+1∂si∥≤(γ)t−k\left\| \frac{\partial s_t}{\partial s_k} \right\| \approx \left\| \prod_{i=k}^{t-1} \frac{\partial s_{i+1}}{\partial s_i} \right\| \leq (\gamma)^{t-k}sksti=kt1sisi+1(γ)tk
其中γ<1\gamma < 1γ<1

影响

  • 早期时间步参数更新不足
  • 难以学习长期依赖
  • 模型偏向短期模式

梯度爆炸问题

现象:梯度值指数级增长

指数增长
时间t
梯度
t-n

数学本质
∥∂st∂sk∥≥(γ)t−k\left\| \frac{\partial s_t}{\partial s_k} \right\| \geq (\gamma)^{t-k}skst(γ)tk
其中γ>1\gamma > 1γ>1

影响

  • 参数更新过大
  • 训练不稳定
  • NaN损失值

RNN在现代AI中的地位

尽管Transformer异军突起,RNN仍在特定领域保持独特优势:

  1. 时间序列预测

    • 金融市场价格预测
    • 工业设备故障预警
    • 医疗监测
  2. 实时语音处理

    • 语音识别
    • 语音合成
    • 实时翻译
  3. 控制与机器人

    • 运动控制
    • 传感器融合
    • 自适应决策
  4. 神经科学建模

    • 大脑信息处理模拟
    • 认知过程研究

结语

循环神经网络作为序列建模的奠基者,其核心价值在于对时间本质的深刻把握。在信息以连续流形式存在的世界中,RNN提供了一种自然的计算范式:

  • 状态连续性:通过隐藏状态传递历史信息
  • 时间局部性:逐步处理序列元素
  • 动态计算:适应变长输入输出

“循环神经网络教会我们,时间不是离散的帧,而是连续的流;记忆不是静态的存储,而是动态的重构。”

尽管新型架构不断涌现,RNN所确立的序列建模原则——状态传递、时间展开、循环计算——仍将影响未来AI的发展方向。理解RNN不仅是掌握一项技术,更是理解智能系统如何处理时间这一根本维度的重要窗口。

http://www.dtcms.com/a/271502.html

相关文章:

  • cherryStudio electron因为环境问题无法安装解决方法或打包失败解决方法
  • NLP自然语言处理04 transformer架构模拟实现
  • Git版本控制完全指南:从入门到实战(简单版)
  • 【02】MFC入门到精通——MFC 手动添加创建新的对话框模板
  • 【PyTorch】PyTorch中torch.nn模块的全连接层
  • C++每日刷题 day2025.7.09
  • 备受期待的 MMORPG 游戏《侍魂R》移动端现已上线 Sui
  • RK3588 buildroot 解决软件包无法下载
  • 用户查询优惠券之缓存击穿
  • RAC-CELL(小区)处理
  • Ubuntu连接不上网络问题(Network is unreachable)
  • 国产航顺HK32F030M: 串口调试debug,重定向c库函数printf到串口,重定向后可使用printf函数
  • 记一次接口优化历程 CountDownLatch
  • C语言模块化编程思维以及直流电机控制(第四天)
  • 深度学习——损失函数
  • 【使用Flask基于PaddleOCR3.0开发一个接口 调用时报错RuntimeError: std::exception】
  • JVM调优实战指南:让Java程序性能飞升的奥秘
  • PanTS: The Pancreatic Tumor Segmentation Dataset
  • 使用anaconda创建基础环境
  • 数据分析框架和方法
  • 数据分析-名词
  • pip 安装加速指南:配置国内镜像源(中国科技大学、清华、阿里云等)
  • Java武林:虚拟机之道 第七章:秘籍解析 - JVM调优参数
  • 经验分享-没有xcode也可以上传App Store Connect
  • S7-1500——(一)从入门到精通1、基于TIA 博途解析PLC程序结构(一)
  • c语言中的数组II
  • 景观桥 涵洞 城门等遮挡物对汽车安全性的影响数学建模和计算方法,需要收集那些数据
  • 周立功汽车软件ZXDoc深度解析:新能源汽车开发新基建的破局之道
  • java 语法类新特性总结
  • 【王树森推荐系统】排序05:排序模型的特征