WITRAN_2DPSGMU_Encoder 类中,门机制
WITRAN_2DPSGMU_Encoder
类中的门机制详解
在 WITRAN_2DPSGMU_Encoder
类中,门机制是核心部分,类似于 LSTM 或 GRU 的门控机制,用于控制隐藏状态的更新和输出。以下是对门机制的详细解析。
1. 门机制的作用
门机制的主要作用是:
-
控制信息流动:
- 决定当前时间步的输入信息和上一时间步的隐藏状态如何结合。
- 通过更新门、输入门和输出门,选择性地保留或丢弃信息。
-
更新隐藏状态:
- 根据门控信号,更新行隐藏状态(
hidden_slice_row
)和列隐藏状态(hidden_slice_col
)。
- 根据门控信号,更新行隐藏状态(
-
生成输出:
- 将更新后的隐藏状态拼接,作为当前时间步的输出。
2. 门机制的组成
门机制由以下几部分组成:
(1) 门控信号的生成
gate = self.linear(torch.cat([hidden_slice_row, hidden_slice_col, a[:, slice, :]], dim=-1), W, B, batch_size, slice, Water2sea_slice_num)
-
输入:
hidden_slice_row
:行隐藏状态,形状为[128, 32]
。hidden_slice_col
:列隐藏状态,形状为[128, 32]
。a[:, slice, :]
:当前时间步的输入,形状为[128, 11]
。- 拼接后,输入形状为
[128, 75]
。
-
线性变换:
- 权重矩阵
W
的形状为[192, 75]
。 - 偏置向量
B
的形状为[192]
。 - 输出
gate
的形状为[128, 192]
。
- 权重矩阵
(2) 分割门控信号
sigmod_gate, tanh_gate = torch.split(gate, 4 * self.hidden_size, dim=-1)
- 分割结果:
sigmod_gate
:前128
列,用于生成更新门和输出门。tanh_gate
:后64
列,用于生成输入门。
(3) 激活函数处理
-
Sigmoid 激活:
sigmod_gate = torch.sigmoid(sigmod_gate)
- 将
sigmod_gate
的值映射到[0, 1]
区间。 - 用于生成更新门和输出门的开关信号。
- 将
-
Tanh 激活:
tanh_gate = torch.tanh(tanh_gate)
- 将
tanh_gate
的值映射到[-1, 1]
区间。 - 用于生成输入门的候选值。
- 将
(4) 分割门控信号为具体的门
update_gate_row, output_gate_row, update_gate_col, output_gate_col = sigmod_gate.chunk(4, dim=-1)
input_gate_row, input_gate_col = tanh_gate.chunk(2, dim=-1)
-
更新门:
update_gate_row
:更新行隐藏状态的门控信号,形状为[128, 32]
。update_gate_col
:更新列隐藏状态的门控信号,形状为[128, 32]
。
-
输出门:
output_gate_row
:输出行隐藏状态的门控信号,形状为[128, 32]
。output_gate_col
:输出列隐藏状态的门控信号,形状为[128, 32]
。
-
输入门:
input_gate_row
:用于更新行隐藏状态的输入门信号,形状为[128, 32]
。input_gate_col
:用于更新列隐藏状态的输入门信号,形状为[128, 32]
。
3. 隐藏状态的更新
(1) 更新行隐藏状态
hidden_slice_row = torch.tanh((1-update_gate_row)*hidden_slice_row + update_gate_row*input_gate_row) * output_gate_row
-
计算过程:
(1 - update_gate_row) * hidden_slice_row
:- 保留上一时间步的行隐藏状态。
update_gate_row * input_gate_row
:- 引入当前时间步的输入信息。
- 相加后通过
torch.tanh
激活:- 将结果映射到
[-1, 1]
区间。
- 将结果映射到
- 乘以
output_gate_row
:- 控制隐藏状态的输出强度。
-
结果:
- 更新后的行隐藏状态
hidden_slice_row
,形状为[128, 32]
。
- 更新后的行隐藏状态
(2) 更新列隐藏状态
hidden_slice_col = torch.tanh((1-update_gate_col)*hidden_slice_col + update_gate_col*input_gate_col) * output_gate_col
-
计算过程:
- 与更新行隐藏状态的逻辑相同,但作用于列隐藏状态。
-
结果:
- 更新后的列隐藏状态
hidden_slice_col
,形状为[128, 32]
。
- 更新后的列隐藏状态
4. 输出的生成
output_slice = torch.cat([hidden_slice_row, hidden_slice_col], dim=-1)
-
拼接隐藏状态:
- 将行隐藏状态
hidden_slice_row
和列隐藏状态hidden_slice_col
在最后一个维度上拼接。 - 输出形状为
[128, 64]
。
- 将行隐藏状态
-
保存输出:
output_all_slice_list.append(output_slice)
- 将当前时间步的输出保存到
output_all_slice_list
中。
- 将当前时间步的输出保存到
5. 门机制的核心逻辑总结
-
门控信号的生成:
- 通过线性变换生成门控信号
gate
,并分割为更新门、输出门和输入门。
- 通过线性变换生成门控信号
-
隐藏状态的更新:
- 使用更新门和输入门结合上一时间步的隐藏状态,生成新的隐藏状态。
- 使用输出门控制隐藏状态的输出强度。
-
输出的生成:
- 将行隐藏状态和列隐藏状态拼接,作为当前时间步的输出。
6. 图示化表示
输入:
hidden_slice_row [128, 32]
hidden_slice_col [128, 32]
当前时间步输入 a[:, slice, :] [128, 11]
拼接后输入 [128, 75]
↓
线性变换:
gate = self.linear(...) → [128, 192]
↓
分割 gate:
sigmod_gate [128, 128] → 更新门、输出门
tanh_gate [128, 64] → 输入门
↓
激活函数:
Sigmoid → sigmod_gate
Tanh → tanh_gate
↓
分割门控信号:
更新门:update_gate_row, update_gate_col
输出门:output_gate_row, output_gate_col
输入门:input_gate_row, input_gate_col
↓
更新隐藏状态:
hidden_slice_row → 更新行隐藏状态
hidden_slice_col → 更新列隐藏状态
↓
生成输出:
output_slice = torch.cat([hidden_slice_row, hidden_slice_col], dim=-1) → [128, 64]
7. 总结
- 更新门:控制上一时间步的隐藏状态保留多少。
- 输入门:控制当前时间步的输入信息引入多少。
- 输出门:控制隐藏状态的输出强度。
- 最终输出:将行隐藏状态和列隐藏状态拼接,作为当前时间步的输出。
这种门机制类似于 LSTM,但针对二维时间序列数据进行了扩展,分别更新行和列的隐藏状态,从而捕获时间序列的复杂依赖关系。