深刻理解PyTorch中RNN(循环神经网络)的output和hn
零 药引·简短代码
import torch
from torch import nn
# seq_len/vocal_size = 3, batch_size=50, input_dimension=10
inputs = torch.randn((3, 50, 10))drnn1 = nn.RNN(input_size=10, num_layers=4, hidden_size=20)
outputs1, hn3 = drnn1(inputs)print(outputs1.shape) # 所有时间步上,最后一层的隐藏状态
print(hn3.shape) # 最后一个时间步上,所有层的隐藏状态
print(outputs1[-1, 0, :])
print(hn3[-1, 0, :])
torch.Size([3, 50, 20])
torch.Size([4, 50, 20])
tensor([-0.2765, 0.4682, -0.0817, -0.2519, -0.1041, 0.0172, 0.2613, -0.0619,-0.2646, 0.0591, 0.1749, -0.1277, 0.3200, -0.3987, -0.2516, 0.1340,-0.3838, 0.2305, -0.2042, -0.2924], grad_fn=<SliceBackward0>)
tensor([-0.2765, 0.4682, -0.0817, -0.2519, -0.1041, 0.0172, 0.2613, -0.0619,-0.2646, 0.0591, 0.1749, -0.1277, 0.3200, -0.3987, -0.2516, 0.1340,-0.3838, 0.2305, -0.2042, -0.2924], grad_fn=<SliceBackward0>)
一 深入理解
- 理解
outputs1
和hn3
是理解 PyTorch 中 RNN(循环神经网络)的关键。用一个生动形象的比喻来理解它。
1.1 核心比喻:一场“接力赛跑”
想象一下,有一个非常特殊的接力赛跑团队,来处理你输入的数据。
- 跑道:代表你的输入序列。你的
inputs
形状是(3, 50, 10)
,其中seq_len=3
,所以这条跑道有 3 段(3个时间步)。 - 运动员:代表你的数据批次。
batch_size=50
,代表有50个赛道,所以同时有 50 个运动员 在各自的跑道上比赛。 - 每段跑道的任务:运动员在每一段跑道上,都会遇到一个 “信息牌”(输入数据
input_size=10
),他需要读懂这个信息牌,然后决定怎么跑。 - 接力团队:代表你的 RNN 层。
num_layers=4
,意味着每个运动员都有一个 4 人 的专属接力团队。 - 每个接力队员的“大脑”:代表隐藏状态。
hidden_size=20
,意味着每个队员的“大脑”能处理和记忆 20 个 维度的信息。
1.2 比赛开始!
现在,我们让这 50 个运动员(batch)同时开始他们的 3 段跑道(序列)比赛。
1.2.1 outputs1
- 每个运动员在每个跑道点的“即时状态报告”
outputs1
的形状是 (3, 50, 20)
。
3
(seq_len): 代表跑道的 3 个检查点(或者说,3段跑道的终点)。50
(batch_size): 代表 50 个运动员。20
(hidden_size): 代表每个运动员在检查点时,他最后一棒接力队员的“大脑”状态。
形象理解:
outputs1
就像是一台高速摄像机,它记录下了 所有运动员 在 每一个检查点 冲线时的瞬间状态。
outputs1[0]
(形状是(50, 20)
):这是 第1个检查点 的快照。它包含了 50个运动员 在跑完第1段后,他们各自的“第4棒”队员(因为是最深层)的“大脑”状态。outputs1[1]
(形状是(50, 20)
):这是 第2个检查点 的快照。包含了 50个运动员跑完第2段后的“第4棒”队员状态。outputs1[2]
(形状是(50, 20)
):这是 第3个检查点(终点) 的快照。包含了 50个运动员跑完全程后的“第4棒”队员状态。
一句话总结 outputs1
:
“我需要知道每个时间步(每个检查点)的最终输出是什么?”
比如,在文本情感分析中,你想知道句子中每个词之后的情感倾向,outputs1
就能提供这个信息。它包含了整个序列的完整记忆。
1.2.2 hn3
- 所有运动员比赛结束后,他们整个接力团队的“最终状态报告”
hn3
的形状是 (4, 50, 20)
。
4
(num_layers): 代表每个运动员的 4 人接力团队。50
(batch_size): 代表 50 个运动员。20
(hidden_size): 代表每个接力队员的“大脑”状态。
形象理解:
hn3
就像是比赛结束后,组委会收集的 所有运动员的完整团队报告。hn3[0]
(形状是(50, 20)
):这是 所有运动员的“第1棒”队员 在完成自己全部任务(跑完3段)后的最终“大脑”状态。hn3[1]
(形状是(50, 20)
):这是 所有运动员的“第2棒”队员 的最终“大脑”状态。hn3[2]
(形状是(50, 20)
):这是 所有运动员的“第3棒”队员 的最终“大脑”状态。hn3[3]
(形状是(50, 20)
):这是 所有运动员的“第4棒”队员 的最终“大脑”状态。
关键点: hn3
只记录了最后时刻(第3个时间步)的状态。它不关心第1、第2个时间步发生了什么,它只关心整个接力赛跑结束后,每个团队里每个人的最终状态。
一句话总结 hn3
:
“我处理完整个序列后,最终的、浓缩的记忆是什么?”
比如,在文本分类中,你读完整句话后,只需要一个最终的情感判断(积极/消极)。hn3
,特别是它的最后一层hn3[-1]
,就非常适合作为这个最终判断的依据。它代表了处理完所有信息后的最终沉淀。
1.3 两者关系和验证
现在,来看一个非常重要的关系,这能让你彻底明白它们之间的联系。
outputs1
记录了每个时间点的最后一层的输出。
hn3
记录了最后一个时间点的所有层的输出。
那么,outputs1
的最后一个时间点的数据,和 hn3
的最后一层的数据,是不是同一个东西?是的!
用代码验证一下:
import torch
from torch import nn
# 你的原始代码
inputs = torch.randn((3, 50, 10))
drnn1 = nn.RNN(input_size=10, num_layers=4, hidden_size=20)
outputs1, hn3 = drnn1(inputs)
# --- 验证环节 ---
# 1. 取出 outputs1 的最后一个时间步 (第2个索引) 的所有数据
# outputs1[-1] 的形状是 (50, 20)
last_output_from_outputs1 = outputs1[-1]
# 2. 取出 hn3 的最后一层 (第3个索引) 的所有数据
# hn3[-1] 的形状是 (50, 20)
last_layer_from_hn3 = hn3[-1]
# 3. 比较这两个张量是否完全相等
# torch.allclose() 是一个安全的方法,比较浮点数是否在误差范围内相等
are_they_equal = torch.allclose(last_output_from_outputs1, last_layer_from_hn3)
print(f"outputs1 的形状: {outputs1.shape}")
print(f"hn3 的形状: {hn3.shape}")
print("-" * 30)
print(f"outputs1[-1] (最后一个时间步的输出) 的形状: {last_output_from_outputs1.shape}")
print(f"hn3[-1] (最后一层的隐藏状态) 的形状: {last_layer_from_hn3.shape}")
print("-" * 30)
print(f"这两个张量是否相等? {are_they_equal}")
运行结果:
outputs1 的形状: torch.Size([3, 50, 20])
hn3 的形状: torch.Size([4, 50, 20])
------------------------------
outputs1[-1] (最后一个时间步的输出) 的形状: torch.Size([50, 20])
hn3[-1] (最后一层的隐藏状态) 的形状: torch.Size([50, 20])
------------------------------
这两个张量是否相等? True
结果解读:
True
这个结果完美地印证了我们的比喻:
在比赛结束的那一刻(最后一个时间步),摄影师拍下的“最后一棒队员冲线照片”(
outputs1[-1]
),和组委会收到的“团队报告中关于最后一棒队员的描述”(hn3[-1]
),是同一个东西。
1.4 总结表格
特性 | outputs1 | hn3 |
---|---|---|
比喻 | 即时状态报告 (每个检查点的快照) | 最终状态报告 (赛后团队总结) |
关注点 | 时间维度 (序列中的每一步) | 层次维度 (网络中的每一层) |
形状 | (seq_len, batch_size, hidden_size) | (num_layers, batch_size, hidden_size) |
包含信息 | 所有时间步上,最后一层的隐藏状态 | 最后一个时间步上,所有层的隐藏状态 |
典型用途 | 序列标注 (如词性标注)、语音识别 (需要每个时间步的输出) | 序列分类 (如情感分析)、只关心最终结果的场景 |
关键关系 | outputs1[-1] 等于 hn3[-1] | hn3[-1] 等于 outputs1[-1] |
二 RNN中 outputs1
与 hn3
可视化理解图 (ASCII版)
- 再次以“接力赛跑”为喻,用字符画来描绘整个过程。
2.1 整体流程概览
想象一下,我们有50条并行的跑道(代表50个batch)。我们只看其中一条跑道的情况,因为其他49条是完全一样的。
+-------------------+ +-------------------+ +-------------------+
| Time Step 1 | | Time Step 2 | | Time Step 3 |
| (第一棒) | | (第二棒) | | (第三棒) |
| | | | | |
| Input: (10,) | | Input: (10,) | | Input: (10,) |
| (第一段数据) | | (第二段数据) | | (第三段数据) |
| | | | | | | | |
| v | | v | | v |
| +-------------+ | | +-------------+ | | +-------------+ |
| | Layer 1 | |---->| | Layer 1 | |---->| | Layer 1 | |
| | (h_11) | | | | (h_21) | | | | (h_31) | |
| +-------------+ | | +-------------+ | | +-------------+ |
| | | | | | | | |
| v | | v | | v |
| +-------------+ | | +-------------+ | | +-------------+ |
| | Layer 2 | |---->| | Layer 2 | |---->| | Layer 2 | |
| | (h_12) | | | | (h_22) | | | | (h_32) | |
| +-------------+ | | +-------------+ | | +-------------+ |
| | | | | | | | |
| v | | v | | v |
| +-------------+ | | +-------------+ | | +-------------+ |
| | Layer 3 | |---->| | Layer 3 | |---->| | Layer 3 | |
| | (h_13) | | | | (h_23) | | | | (h_33) | |
| +-------------+ | | +-------------+ | | +-------------+ |
| | | | | | | | |
| v | | v | | v |
| +-------------+ | | +-------------+ | | +-------------+ |
| | Layer 4 | |---->| | Layer 4 | |---->| | Layer 4 | |
| | (h_14) | | | | (h_24) | | | | (h_34) | | <----+
| +-------------+ | | +-------------+ | | +-------------+ | |
+-------------------+ +-------------------+ +-------------------+ |^ ^ ^ || | | || (初始隐藏 状态) | | |+-------------------------+-------------------------+--------------------+|| (hn3 的来源)v
图解:
- 横轴是时间: 从左到右,代表你的输入序列的3个时间步。
- 纵轴是网络层: 每个时间步内部,数据从下往上流经4个网络层。
- 箭头
--->
: 代表信息的传递。横向箭头是同一层在不同时间步之间的传递(接力赛中的“交接棒”),纵向箭头是同一时间步内不同层之间的传递。 h_ij
: 代表第i
个时间步,第j
个层的隐藏状态。例如,h_23
就是第2个时间步,第3个层的输出。
2.2 outputs1
是什么?(沿途的终点线摄影)
outputs1
就像是在每一棒的终点线都架设了一台高速摄像机,只拍摄最后一棒运动员(Layer 4)冲线时的照片。
+-------------------+ +-------------------+ +-------------------+
| Time Step 1 | | Time Step 2 | | Time Step 3 |
| | | | | |
| ... (Layers 1-3) | | ... (Layers 1-3) | | ... (Layers 1-3) |
| | | | | | | | |
| v | | v | | v |
| +-------------+ | | +-------------+ | | +-------------+ |
| | Layer 4 | | | | Layer 4 | | | | Layer 4 | |
| | (h_14) | | | | (h_24) | | | | (h_34) | |
| +-------------+ | | +-------------+ | | +-------------+ |
| | | | | | | | |
| v | | v | | v |
| [ Photo 1 ] | | [ Photo 2 ] | | [ Photo 3 ] |
| (h_14 的快照) | | (h_24 的快照) | | (h_34 的快照) |
+-------------------+ +-------------------+ +-------------------+| | ||-------------------------+-------------------------||v+---------------------------------------+| outputs1 (相册) || 形状: (3, 50, 20) || 内容: [Photo1, Photo2, Photo3] || [h_14, h_24, h_34] |+---------------------------------------+
outputs1
总结:
- 内容: 它收集了 所有时间步 的 最后一层 的隐藏状态。
- 形状 (3, 50, 20):
3
: 对应3个时间步(3张照片)。50
: 对应50个batch(50个运动员,每人一本相册)。20
: 对应隐藏层大小(每张照片有20个维度的信息)。
- 用途: 当你需要序列中每一步的输出时,比如给每个词标注词性,或者在语音识别中识别每一帧的声音。
2.3 hn3
是什么?(比赛结束后的团队总结报告)
hn3
就像比赛结束后,教练员记录的最终总结报告。这份报告记录了最后一棒冲线时,所有4名队员(所有层)的最终状态。
+-------------------+ +-------------------+ +-------------------+
| Time Step 1 | | Time Step 2 | | Time Step 3 |
| | | | | |
| ... (Layers 1-3) | | ... (Layers 1-3) | | ... (Layers 1-3) |
| | | | | | | | |
| v | | v | | v |
| +-------------+ | | +-------------+ | | +-------------+ |
| | Layer 1 | | | | Layer 1 | | | | Layer 1 | | --+
| | (h_11) | | | | (h_21) | | | | (h_31) | | |
| +-------------+ | | +-------------+ | | +-------------+ | |
| | | | | | | | | |
| v | | v | | v | |
| +-------------+ | | +-------------+ | | +-------------+ | |
| | Layer 2 | | | | Layer 2 | | | | Layer 2 | | --+
| | (h_12) | | | | (h_22) | | | | (h_32) | | |
| +-------------+ | | +-------------+ | | +-------------+ | |
| | | | | | | | | |
| v | | v | | v | |
| +-------------+ | | +-------------+ | | +-------------+ | |
| | Layer 3 | | | | Layer 3 | | | | Layer 3 | | --+
| | (h_13) | | | | (h_23) | | | | (h_33) | | |
| +-------------+ | | +-------------+ | | +-------------+ | |
| | | | | | | | | |
| v | | v | | v | |
| +-------------+ | | +-------------+ | | +-------------+ | --+
| | Layer 4 | | | | Layer 4 | | | | Layer 4 | |
| | (h_14) | | | | (h_24) | | | | (h_34) | |
| +-------------+ | | +-------------+ | | +-------------+ |
+-------------------+ +-------------------+ +-------------------+|| (比赛结束,收集所有队员状态)v+------------------------------------------+| hn3 (总结报告) || 形状: (4, 50, 20) || 内容: [ Layer1, Layer2, Layer3, Layer4] || [ h_31, h_32, h_33, h_34] |+-------------------------------------------+
hn3
总结:
- 内容: 它收集了 最后一个时间步 的 所有层 的隐藏状态。
- 形状 (4, 50, 20):
4
: 对应4个网络层(报告里有4名队员的总结)。50
: 对应50个batch(50个运动员,每人一份报告)。20
: 对应隐藏层大小(每个队员的状态有20个维度)。
- 用途: 当你只关心整个序列的最终结果时,比如判断一段话的整体情感,或者将一整句话翻译成一个概括性的词。
2.4 核心交汇点 (最重要的关系)
现在,我们把上面两个图的关键部分合在一起看:
(来自 outputs1 的收集)|v
+-------------------+
| Time Step 3 |
| |
| ... (Layers 1-3) |
| | |
| v |
| +-------------+ |
| | Layer 4 | | <------------------------------------+
| | (h_34) | | |
| +-------------+ | |
+-------------------+ || || (来自 hn3 的收集) |v |
+---------------------------------------+ |
| hn3 (总结报告) | |
| 形状: (4, 50, 20) | |
| 内容: [ h_31, h_32, h_33, h_34 ] | |
+---------------------------------------+ |^ || |+--------------------------------------------------+|outputs1[-1] 是 h_34hn3[-1] 也是 h_34所以 outputs1[-1] == hn3[-1]
核心关系总结:
outputs1
的最后一个元素 (outputs1[-1]
),就是 Time Step 3 的 Layer 4 的输出h_34
。hn3
的最后一个元素 (hn3[-1]
),也是 Time Step 3 的 Layer 4 的输出h_34
。- 它们指向的是完全同一个数据!就像一个运动员冲线的瞬间,既被终点摄像机拍下(成为
outputs1
的一部分),也被记录在团队的最终报告里(成为hn3
的一部分)。