当前位置: 首页 > news >正文

【Torch】nn.GRU算法详解

1. 输入输出

  1. 输入张量

    • 默认形状:(seq_len, batch_size, input_size)
    • batch_first=True(batch_size, seq_len, input_size)
    • 含义:序列长度 × 批大小 × 每步特征维度
  2. 可选初始隐状态

    • 形状:(num_layers * num_directions, batch_size, hidden_size)
    • 默认为全零张量。如果要自定义,需提供此形状的 h0
  3. 输出
    调用 output, h_n = gru(x, h0) 返回两部分:

    • output:所有时间步的隐藏状态序列
      • 形状:
        • 默认:(seq_len, batch_size, num_directions * hidden_size)
        • batch_first=True(batch_size, seq_len, num_directions * hidden_size)
      • 含义:每个时间步的隐藏状态,可以直接接全连接或其它后续层。
    • h_n:最后一个时间步的隐藏状态
      • 形状:(num_layers * num_directions, batch_size, hidden_size)
      • 含义:每一层(及方向)在序列末尾的隐藏状态,常用于初始化下一个序列或分类任务。

2. 构造函数参数详解

nn.GRU(input_size: int,hidden_size: int,num_layers: int = 1,bias: bool = True,batch_first: bool = False,dropout: float = 0.0,bidirectional: bool = False
)
参数类型含义
input_sizeint输入特征维度,即每步输入向量的大小。
hidden_sizeint隐状态(隐藏层)维度,也决定输出特征维度(单向时即 hidden_size)。
num_layersint堆叠的 GRU 层数(深度),默认为 1。
biasbool是否使用偏置;当为 False 时,所有线性变换均无 bias。
batch_firstbool是否将批量维放到第二维(True),默认序列维在最前(False)。
dropoutfloat除最后一层外,每层输出后使用的 Dropout 比例;仅在 num_layers>1 时生效。
bidirectionalbool是否使用双向 RNN;若 True,则隐状态和输出维度翻倍。

3. 输出含义详解

  • output

    • 大小:[..., num_directions * hidden_size]
    • 如果 bidirectional=Falsenum_directions=1;否则 =2
    • output[t, b, :](或在 batch_first 模式下 output[b, t, :])表示第 t 步第 b 个样本的隐藏状态。
  • h_n

    • 大小:(num_layers * num_directions, batch_size, hidden_size)
    • 维度索引含义:
      • 维度 0:层数 × 方向(例如 3 层双向时索引 0–5,对应层1正向、层1反向、层2正向…)
      • 维度 1:批内样本索引
      • 维度 2:隐藏状态向量

4. 使用注意事项

  1. batch_first 的选择

    • 若后续直接接全连接层、BatchNorm 等,更习惯 batch_first=True;否则可用默认格式节省一次转置。
  2. 双向与输出维度

    • bidirectional=True 时,output 的最后一维和 h_nhidden_size 均会翻倍,需要相应修改下游网络维度。
  3. Dropout 的生效条件

    • 只有在 num_layers > 1 并且 dropout > 0 时,才会在各层间插入 Dropout;单层时不会应用。
  4. 初始隐状态

    • 默认为零。若在两个连续序列之间保持状态(stateful RNN),可将上一次的 h_n 作为下一次的 h0
  5. PackedSequence

    • 对变长序列,可用 torch.nn.utils.rnn.pack_padded_sequence 输入,输出再用 pad_packed_sequence 恢复,对长短不一的序列批处理很有用。
  6. 性能与稳定性

    • GRU 相比 LSTM 参数更少、速度稍快,但有时在长期依赖或梯度流问题上略不如 LSTM。
    • 可在多层 RNN 之间加 LayerNorm 或 Residual 连接,提升深度模型的收敛和稳定性。

简单示例

import torch, torch.nn as nn# 定义单层单向 GRU
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2,batch_first=True, dropout=0.1, bidirectional=True)# 输入:batch=8, seq_len=15, features=10
x = torch.randn(8, 15, 10)# 默认 h0 为零
output, h_n = gru(x)
print(output.shape)  # (8, 15, 2*20)  双向,所以 hidden_size*2
print(h_n.shape)     # (2*2, 8, 20)  num_layers=2, num_directions=2

相关文章:

  • Java 类加载机制详解
  • 高级版 Web Worker 封装(含 WorkerPool 调度池 + 超时控制)
  • 渗透测试指南(CSMSF):Windows 与 Linux 系统中的日志与文件痕迹清理
  • 【时时三省】(C语言基础)怎样定义指针变量
  • 无人驾驶汽车运动控制分为纵向控制和横向控制
  • 软件更新 | 从数据到模型,全面升级!TSMaster新版助力汽车研发新突破
  • AIGC工具平台-FishSpeech零样本语音合成
  • 用 GitHub Issues 做任务管理和任务 List,简单好用!
  • 《Redis高并发优化策略与规范清单:从开发到运维的全流程指南》
  • 关于变换矩阵的计算
  • 同源数据互补修复机制:从DNA修复到分布式系统的可靠性设计
  • fiddler+安卓模拟器,解决无网络、抓不到https问题
  • 【Linux网络编程】序列化与反序列化
  • 组件化设计核心:接口与实现分离(C++)
  • JAVA学习-练习试用Java实现“TensorFlow/Deeplearning4j:利用DL4J构建卷积神经网络进行图像分类”
  • ios签名错误的解决办法
  • 百胜软件胜券AI:打造智慧零售运营新范式
  • 布瑞琳BRANEW:高端洗护领航者,铸就品质生活新典范
  • TestCafe 全解析:免费开源的 E2E 测试解决方案实战指南
  • 【C#】C#异步编程:异步延时 vs 阻塞延时深度对比
  • 搭建网站用服务器还是虚拟主机/网络软文广告
  • 重庆秀山网站建设价格/西安疫情最新数据
  • 做网站公司在哪/网址大全浏览器下载
  • 怎么做根优酷差不多的网站/北京环球影城每日客流怎么看
  • wordpress 判断 手机/seo引擎
  • 网站支付怎么做的/可视化网页制作工具