当前位置: 首页 > news >正文

深刻理解PyTorch中RNN(循环神经网络)的output和hn

零 药引·简短代码

import torch
from torch import nn
# seq_len/vocal_size = 3, batch_size=50, input_dimension=10
inputs = torch.randn((3, 50, 10))drnn1 = nn.RNN(input_size=10, num_layers=4, hidden_size=20)
outputs1, hn3 = drnn1(inputs)print(outputs1.shape)  # 所有时间步上,最后一层的隐藏状态
print(hn3.shape)  # 最后一个时间步上,所有层的隐藏状态
print(outputs1[-1, 0, :])
print(hn3[-1, 0, :])
torch.Size([3, 50, 20])
torch.Size([4, 50, 20])
tensor([-0.2765,  0.4682, -0.0817, -0.2519, -0.1041,  0.0172,  0.2613, -0.0619,-0.2646,  0.0591,  0.1749, -0.1277,  0.3200, -0.3987, -0.2516,  0.1340,-0.3838,  0.2305, -0.2042, -0.2924], grad_fn=<SliceBackward0>)
tensor([-0.2765,  0.4682, -0.0817, -0.2519, -0.1041,  0.0172,  0.2613, -0.0619,-0.2646,  0.0591,  0.1749, -0.1277,  0.3200, -0.3987, -0.2516,  0.1340,-0.3838,  0.2305, -0.2042, -0.2924], grad_fn=<SliceBackward0>)

一 深入理解

  • 理解 outputs1hn3 是理解 PyTorch 中 RNN(循环神经网络)的关键。用一个生动形象的比喻来理解它。

1.1 核心比喻:一场“接力赛跑”

想象一下,有一个非常特殊的接力赛跑团队,来处理你输入的数据。

  • 跑道:代表你的输入序列。你的 inputs 形状是 (3, 50, 10),其中 seq_len=3,所以这条跑道有 3 段(3个时间步)。
  • 运动员:代表你的数据批次batch_size=50,代表有50个赛道,所以同时有 50 个运动员 在各自的跑道上比赛。
  • 每段跑道的任务:运动员在每一段跑道上,都会遇到一个 “信息牌”(输入数据 input_size=10),他需要读懂这个信息牌,然后决定怎么跑。
  • 接力团队:代表你的 RNN 层num_layers=4,意味着每个运动员都有一个 4 人 的专属接力团队。
  • 每个接力队员的“大脑”:代表隐藏状态hidden_size=20,意味着每个队员的“大脑”能处理和记忆 20 个 维度的信息。

1.2 比赛开始!

现在,我们让这 50 个运动员(batch)同时开始他们的 3 段跑道(序列)比赛。

1.2.1 outputs1 - 每个运动员在每个跑道点的“即时状态报告”

outputs1 的形状是 (3, 50, 20)

  • 3 (seq_len): 代表跑道的 3 个检查点(或者说,3段跑道的终点)。
  • 50 (batch_size): 代表 50 个运动员
  • 20 (hidden_size): 代表每个运动员在检查点时,他最后一棒接力队员的“大脑”状态。

形象理解:
outputs1 就像是一台高速摄像机,它记录下了 所有运动员每一个检查点 冲线时的瞬间状态。

  • outputs1[0] (形状是 (50, 20)):这是 第1个检查点 的快照。它包含了 50个运动员 在跑完第1段后,他们各自的“第4棒”队员(因为是最深层)的“大脑”状态。
  • outputs1[1] (形状是 (50, 20)):这是 第2个检查点 的快照。包含了 50个运动员跑完第2段后的“第4棒”队员状态。
  • outputs1[2] (形状是 (50, 20)):这是 第3个检查点(终点) 的快照。包含了 50个运动员跑完全程后的“第4棒”队员状态。

一句话总结 outputs1

“我需要知道每个时间步(每个检查点)的最终输出是什么?”
比如,在文本情感分析中,你想知道句子中每个词之后的情感倾向,outputs1 就能提供这个信息。它包含了整个序列的完整记忆


1.2.2 hn3 - 所有运动员比赛结束后,他们整个接力团队的“最终状态报告”

hn3 的形状是 (4, 50, 20)

  • 4 (num_layers): 代表每个运动员的 4 人接力团队
  • 50 (batch_size): 代表 50 个运动员
  • 20 (hidden_size): 代表每个接力队员的“大脑”状态。
    形象理解:
    hn3 就像是比赛结束后,组委会收集的 所有运动员的完整团队报告
  • hn3[0] (形状是 (50, 20)):这是 所有运动员的“第1棒”队员 在完成自己全部任务(跑完3段)后的最终“大脑”状态。
  • hn3[1] (形状是 (50, 20)):这是 所有运动员的“第2棒”队员 的最终“大脑”状态。
  • hn3[2] (形状是 (50, 20)):这是 所有运动员的“第3棒”队员 的最终“大脑”状态。
  • hn3[3] (形状是 (50, 20)):这是 所有运动员的“第4棒”队员 的最终“大脑”状态。

关键点: hn3 只记录了最后时刻(第3个时间步)的状态。它不关心第1、第2个时间步发生了什么,它只关心整个接力赛跑结束后,每个团队里每个人的最终状态。


一句话总结 hn3

“我处理完整个序列后,最终的、浓缩的记忆是什么?”
比如,在文本分类中,你读完整句话后,只需要一个最终的情感判断(积极/消极)。hn3,特别是它的最后一层 hn3[-1],就非常适合作为这个最终判断的依据。它代表了处理完所有信息后的最终沉淀


1.3 两者关系和验证

现在,来看一个非常重要的关系,这能让你彻底明白它们之间的联系。

outputs1 记录了每个时间点最后一层的输出。

hn3 记录了最后一个时间点所有层的输出。


那么,outputs1最后一个时间点的数据,和 hn3最后一层的数据,是不是同一个东西?是的!

用代码验证一下:

import torch
from torch import nn
# 你的原始代码
inputs = torch.randn((3, 50, 10))
drnn1 = nn.RNN(input_size=10, num_layers=4, hidden_size=20)
outputs1, hn3 = drnn1(inputs)
# --- 验证环节 ---
# 1. 取出 outputs1 的最后一个时间步 (第2个索引) 的所有数据
# outputs1[-1] 的形状是 (50, 20)
last_output_from_outputs1 = outputs1[-1]
# 2. 取出 hn3 的最后一层 (第3个索引) 的所有数据
# hn3[-1] 的形状是 (50, 20)
last_layer_from_hn3 = hn3[-1]
# 3. 比较这两个张量是否完全相等
# torch.allclose() 是一个安全的方法,比较浮点数是否在误差范围内相等
are_they_equal = torch.allclose(last_output_from_outputs1, last_layer_from_hn3)
print(f"outputs1 的形状: {outputs1.shape}")
print(f"hn3 的形状: {hn3.shape}")
print("-" * 30)
print(f"outputs1[-1] (最后一个时间步的输出) 的形状: {last_output_from_outputs1.shape}")
print(f"hn3[-1] (最后一层的隐藏状态) 的形状: {last_layer_from_hn3.shape}")
print("-" * 30)
print(f"这两个张量是否相等? {are_they_equal}")

运行结果:

outputs1 的形状: torch.Size([3, 50, 20])
hn3 的形状: torch.Size([4, 50, 20])
------------------------------
outputs1[-1] (最后一个时间步的输出) 的形状: torch.Size([50, 20])
hn3[-1] (最后一层的隐藏状态) 的形状: torch.Size([50, 20])
------------------------------
这两个张量是否相等? True

结果解读:
True 这个结果完美地印证了我们的比喻:

在比赛结束的那一刻(最后一个时间步),摄影师拍下的“最后一棒队员冲线照片”(outputs1[-1]),和组委会收到的“团队报告中关于最后一棒队员的描述”(hn3[-1]),是同一个东西。

1.4 总结表格

特性outputs1hn3
比喻即时状态报告 (每个检查点的快照)最终状态报告 (赛后团队总结)
关注点时间维度 (序列中的每一步)层次维度 (网络中的每一层)
形状(seq_len, batch_size, hidden_size)(num_layers, batch_size, hidden_size)
包含信息所有时间步上,最后一层的隐藏状态最后一个时间步上,所有层的隐藏状态
典型用途序列标注 (如词性标注)、语音识别 (需要每个时间步的输出)序列分类 (如情感分析)、只关心最终结果的场景
关键关系outputs1[-1] 等于 hn3[-1]hn3[-1] 等于 outputs1[-1]

二 RNN中 outputs1hn3 可视化理解图 (ASCII版)

  • 再次以“接力赛跑”为喻,用字符画来描绘整个过程。

2.1 整体流程概览

想象一下,我们有50条并行的跑道(代表50个batch)。我们只看其中一条跑道的情况,因为其他49条是完全一样的。

+-------------------+     +-------------------+     +-------------------+
|  Time Step 1      |     |  Time Step 2      |     |  Time Step 3      |
|  (第一棒)          |    |  (第二棒)          |      |  (第三棒)         |
|                   |     |                   |     |                   |
|  Input: (10,)     |     |  Input: (10,)     |     |  Input: (10,)     |
|      (第一段数据)  |     |      (第二段数据)  |     |      (第三段数据)  |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|  |  Layer 1    |  |---->|  |  Layer 1    |  |---->|  |  Layer 1    |  |
|  |  (h_11)     |  |     |  |  (h_21)     |  |     |  |  (h_31)     |  |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|  |  Layer 2    |  |---->|  |  Layer 2    |  |---->|  |  Layer 2    |  |
|  |  (h_12)     |  |     |  |  (h_22)     |  |     |  |  (h_32)     |  |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|  |  Layer 3    |  |---->|  |  Layer 3    |  |---->|  |  Layer 3    |  |
|  |  (h_13)     |  |     |  |  (h_23)     |  |     |  |  (h_33)     |  |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|  |  Layer 4    |  |---->|  |  Layer 4    |  |---->|  |  Layer 4    |  |
|  |  (h_14)     |  |     |  |  (h_24)     |  |     |  |  (h_34)     |  | <----+
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |      |
+-------------------+     +-------------------+     +-------------------+      |^                         ^                         ^                    ||                         |                          |                   ||  (初始隐藏 状态)          |                          |                   |+-------------------------+-------------------------+--------------------+|| (hn3 的来源)v

图解:

  • 横轴是时间: 从左到右,代表你的输入序列的3个时间步。
  • 纵轴是网络层: 每个时间步内部,数据从下往上流经4个网络层。
  • 箭头 --->: 代表信息的传递。横向箭头是同一层在不同时间步之间的传递(接力赛中的“交接棒”),纵向箭头是同一时间步内不同层之间的传递。
  • h_ij: 代表第 i 个时间步,第 j 个层的隐藏状态。例如,h_23 就是第2个时间步,第3个层的输出。

2.2 outputs1 是什么?(沿途的终点线摄影)

outputs1 就像是在每一棒的终点线都架设了一台高速摄像机,只拍摄最后一棒运动员(Layer 4)冲线时的照片

+-------------------+     +-------------------+     +-------------------+
|  Time Step 1      |     |  Time Step 2      |     |  Time Step 3      |
|                   |     |                   |     |                   |
|  ... (Layers 1-3) |     |  ... (Layers 1-3) |     |  ... (Layers 1-3) |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|  |  Layer 4    |  |     |  |  Layer 4    |  |     |  |  Layer 4    |  |
|  |  (h_14)     |  |     |  |  (h_24)     |  |     |  |  (h_34)     |  |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|      [ Photo 1 ]  |     |    [ Photo 2 ]    |     |  [ Photo 3 ]      |
|      (h_14 的快照) |     |    (h_24 的快照)   |     |   (h_34 的快照)    |
+-------------------+     +-------------------+     +-------------------+|                         |                         ||-------------------------+-------------------------||v+---------------------------------------+|        outputs1 (相册)                 ||  形状:  (3, 50, 20)                    ||  内容: [Photo1, Photo2, Photo3]        ||        [h_14,   h_24,   h_34]         |+---------------------------------------+

outputs1 总结:

  • 内容: 它收集了 所有时间步最后一层 的隐藏状态。
  • 形状 (3, 50, 20):
    • 3: 对应3个时间步(3张照片)。
    • 50: 对应50个batch(50个运动员,每人一本相册)。
    • 20: 对应隐藏层大小(每张照片有20个维度的信息)。
  • 用途: 当你需要序列中每一步的输出时,比如给每个词标注词性,或者在语音识别中识别每一帧的声音。

2.3 hn3 是什么?(比赛结束后的团队总结报告)

  • hn3 就像比赛结束后,教练员记录的最终总结报告。这份报告记录了最后一棒冲线时,所有4名队员(所有层)的最终状态
+-------------------+     +-------------------+     +-------------------+
|  Time Step 1      |     |  Time Step 2      |     |  Time Step 3      |
|                   |     |                   |     |                   |
|  ... (Layers 1-3) |     |  ... (Layers 1-3) |     |  ... (Layers 1-3) |
|         |         |     |         |         |     |         |         |
|         v         |     |         v         |     |         v         |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
|  |  Layer 1    |  |     |  |  Layer 1    |  |     |  |  Layer 1    |  | --+
|  |  (h_11)     |  |     |  |  (h_21)     |  |     |  |  (h_31)     |  |   |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |   |
|         |         |     |         |         |     |         |         |   |
|         v         |     |         v         |     |         v         |   |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |   |
|  |  Layer 2    |  |     |  |  Layer 2    |  |     |  |  Layer 2    |  | --+
|  |  (h_12)     |  |     |  |  (h_22)     |  |     |  |  (h_32)     |  |   |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |   |
|         |         |     |         |         |     |         |         |   |
|         v         |     |         v         |     |         v         |   |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |   |
|  |  Layer 3    |  |     |  |  Layer 3    |  |     |  |  Layer 3    |  | --+
|  |  (h_13)     |  |     |  |  (h_23)     |  |     |  |  (h_33)     |  |   |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |   |
|         |         |     |         |         |     |         |         |   |
|         v         |     |         v         |     |         v         |   |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  | --+
|  |  Layer 4    |  |     |  |  Layer 4    |  |     |  |  Layer 4    |  |
|  |  (h_14)     |  |     |  |  (h_24)     |  |     |  |  (h_34)     |  |
|  +-------------+  |     |  +-------------+  |     |  +-------------+  |
+-------------------+     +-------------------+     +-------------------+|| (比赛结束,收集所有队员状态)v+------------------------------------------+|           hn3 (总结报告)                  ||  形状:  (4, 50, 20)                       ||  内容: [ Layer1, Layer2, Layer3, Layer4]  ||       [  h_31,   h_32,   h_33,   h_34]    |+-------------------------------------------+

hn3 总结:

  • 内容: 它收集了 最后一个时间步所有层 的隐藏状态。
  • 形状 (4, 50, 20):
    • 4: 对应4个网络层(报告里有4名队员的总结)。
    • 50: 对应50个batch(50个运动员,每人一份报告)。
    • 20: 对应隐藏层大小(每个队员的状态有20个维度)。
  • 用途: 当你只关心整个序列的最终结果时,比如判断一段话的整体情感,或者将一整句话翻译成一个概括性的词。

2.4 核心交汇点 (最重要的关系)

现在,我们把上面两个图的关键部分合在一起看:

      (来自 outputs1 的收集)|v
+-------------------+
|  Time Step 3      |
|                   |
|  ... (Layers 1-3) |
|         |         |
|         v         |
|  +-------------+  |
|  |  Layer 4    |  | <------------------------------------+
|  |  (h_34)     |  |                                     |
|  +-------------+  |                                     |
+-------------------+                                     ||                                                   || (来自 hn3 的收集)                                  |v                                                  |
+---------------------------------------+                |
|           hn3 (总结报告)               |                |
|  形状:  (4, 50, 20)                   |                 |
|  内容: [ h_31, h_32, h_33, h_34 ]     |                 |
+---------------------------------------+                |^                                                  ||                                                  |+--------------------------------------------------+|outputs1[-1]  是  h_34hn3[-1]      也是  h_34所以 outputs1[-1] == hn3[-1]

核心关系总结:

  • outputs1 的最后一个元素 (outputs1[-1]),就是 Time Step 3Layer 4 的输出 h_34
  • hn3 的最后一个元素 (hn3[-1]),也是 Time Step 3Layer 4 的输出 h_34
  • 它们指向的是完全同一个数据!就像一个运动员冲线的瞬间,既被终点摄像机拍下(成为outputs1的一部分),也被记录在团队的最终报告里(成为hn3的一部分)。
http://www.dtcms.com/a/394903.html

相关文章:

  • 大模型如何赋能写作:从创作到 MCP 自动发布的全链路解析
  • C++设计模式之创建型模式:工厂方法模式(Factory Method)
  • 传输层协议——UDP/TCP
  • 三板汇茶咖空间签约“可信资产IPO与数链金融RWA”链改2.0项目联合实验室
  • 【MySQL】MySQL 表文件误删导致启动失败及无法外部连接解决方案
  • LVS简介
  • 如何将联系人从iPhone转移到iPhone的7种方法
  • 『 MySQL数据库 』MySQL复习(一)
  • 3005. 最大频率元素计数
  • ACP(七)优化RAG应用提升问答准确度
  • 鸿蒙:使用bindPopup实现气泡弹窗
  • Langchan4j 框架 AI 无限循环调用文件创建工具解决方案记录
  • Python GIS 开发里最核心的4个基础组件(理论+实操篇)
  • 关于跨域和解决方案
  • 学习日报 20250921|LoadingCache
  • 聚力赋能|竹云受邀出席2025华为全联接大会
  • 抓取 Dump 文件与 WinDbg 使用详解:定位 Windows 程序异常的利器
  • 计算机组成原理:指令周期
  • 老题新解|简单算术表达式求值
  • RustFS与其他新兴存储系统(如SeaweedFS)相比有哪些优势和劣势?
  • WPS标点符号换行问题解决
  • 开发团队的文档自动化革命:WPS+cpolar实战录
  • 【Linux】文本编辑器Vim
  • flink1.18下游配置多个sink
  • 如何删除 MySQL 数据库中的所有数据表 ?
  • win10加域后,控制面板中的,internet 时间就没有了
  • Unity移动平台笔记
  • 【图像算法 - 27】基于YOLOv12与OpenCV的无人机智能检测系统
  • html css js网页制作成品——圣罗兰护肤html+css+js 4页附源码
  • 21届-3年-Java面经-华为od