当前位置: 首页 > news >正文

WITRAN_2DPSGMU_Encoder 类中,门机制

WITRAN_2DPSGMU_Encoder 类中的门机制详解

WITRAN_2DPSGMU_Encoder 类中,门机制是核心部分,类似于 LSTM 或 GRU 的门控机制,用于控制隐藏状态的更新和输出。以下是对门机制的详细解析。


1. 门机制的作用

门机制的主要作用是:

  1. 控制信息流动

    • 决定当前时间步的输入信息和上一时间步的隐藏状态如何结合。
    • 通过更新门、输入门和输出门,选择性地保留或丢弃信息。
  2. 更新隐藏状态

    • 根据门控信号,更新行隐藏状态(hidden_slice_row)和列隐藏状态(hidden_slice_col)。
  3. 生成输出

    • 将更新后的隐藏状态拼接,作为当前时间步的输出。

2. 门机制的组成

门机制由以下几部分组成:

(1) 门控信号的生成
gate = self.linear(torch.cat([hidden_slice_row, hidden_slice_col, a[:, slice, :]], dim=-1), W, B, batch_size, slice, Water2sea_slice_num)
  • 输入

    • hidden_slice_row:行隐藏状态,形状为 [128, 32]
    • hidden_slice_col:列隐藏状态,形状为 [128, 32]
    • a[:, slice, :]:当前时间步的输入,形状为 [128, 11]
    • 拼接后,输入形状为 [128, 75]
  • 线性变换

    • 权重矩阵 W 的形状为 [192, 75]
    • 偏置向量 B 的形状为 [192]
    • 输出 gate 的形状为 [128, 192]

(2) 分割门控信号
sigmod_gate, tanh_gate = torch.split(gate, 4 * self.hidden_size, dim=-1)
  • 分割结果
    • sigmod_gate:前 128 列,用于生成更新门和输出门。
    • tanh_gate:后 64 列,用于生成输入门。

(3) 激活函数处理
  • Sigmoid 激活

    sigmod_gate = torch.sigmoid(sigmod_gate)
    
    • sigmod_gate 的值映射到 [0, 1] 区间。
    • 用于生成更新门和输出门的开关信号。
  • Tanh 激活

    tanh_gate = torch.tanh(tanh_gate)
    
    • tanh_gate 的值映射到 [-1, 1] 区间。
    • 用于生成输入门的候选值。

(4) 分割门控信号为具体的门
update_gate_row, output_gate_row, update_gate_col, output_gate_col = sigmod_gate.chunk(4, dim=-1)
input_gate_row, input_gate_col = tanh_gate.chunk(2, dim=-1)
  • 更新门

    • update_gate_row:更新行隐藏状态的门控信号,形状为 [128, 32]
    • update_gate_col:更新列隐藏状态的门控信号,形状为 [128, 32]
  • 输出门

    • output_gate_row:输出行隐藏状态的门控信号,形状为 [128, 32]
    • output_gate_col:输出列隐藏状态的门控信号,形状为 [128, 32]
  • 输入门

    • input_gate_row:用于更新行隐藏状态的输入门信号,形状为 [128, 32]
    • input_gate_col:用于更新列隐藏状态的输入门信号,形状为 [128, 32]

3. 隐藏状态的更新

(1) 更新行隐藏状态
hidden_slice_row = torch.tanh((1-update_gate_row)*hidden_slice_row + update_gate_row*input_gate_row) * output_gate_row
  • 计算过程

    1. (1 - update_gate_row) * hidden_slice_row
      • 保留上一时间步的行隐藏状态。
    2. update_gate_row * input_gate_row
      • 引入当前时间步的输入信息。
    3. 相加后通过 torch.tanh 激活:
      • 将结果映射到 [-1, 1] 区间。
    4. 乘以 output_gate_row
      • 控制隐藏状态的输出强度。
  • 结果

    • 更新后的行隐藏状态 hidden_slice_row,形状为 [128, 32]

(2) 更新列隐藏状态
hidden_slice_col = torch.tanh((1-update_gate_col)*hidden_slice_col + update_gate_col*input_gate_col) * output_gate_col
  • 计算过程

    • 与更新行隐藏状态的逻辑相同,但作用于列隐藏状态。
  • 结果

    • 更新后的列隐藏状态 hidden_slice_col,形状为 [128, 32]

4. 输出的生成

output_slice = torch.cat([hidden_slice_row, hidden_slice_col], dim=-1)
  • 拼接隐藏状态

    • 将行隐藏状态 hidden_slice_row 和列隐藏状态 hidden_slice_col 在最后一个维度上拼接。
    • 输出形状为 [128, 64]
  • 保存输出

    output_all_slice_list.append(output_slice)
    
    • 将当前时间步的输出保存到 output_all_slice_list 中。

5. 门机制的核心逻辑总结

  1. 门控信号的生成

    • 通过线性变换生成门控信号 gate,并分割为更新门、输出门和输入门。
  2. 隐藏状态的更新

    • 使用更新门和输入门结合上一时间步的隐藏状态,生成新的隐藏状态。
    • 使用输出门控制隐藏状态的输出强度。
  3. 输出的生成

    • 将行隐藏状态和列隐藏状态拼接,作为当前时间步的输出。

6. 图示化表示

输入:
    hidden_slice_row [128, 32]
    hidden_slice_col [128, 32]
    当前时间步输入 a[:, slice, :] [128, 11]
    拼接后输入 [128, 75]
        ↓
线性变换:
    gate = self.linear(...) → [128, 192]
        ↓
分割 gate:
    sigmod_gate [128, 128] → 更新门、输出门
    tanh_gate [128, 64] → 输入门
        ↓
激活函数:
    Sigmoid → sigmod_gate
    Tanh → tanh_gate
        ↓
分割门控信号:
    更新门:update_gate_row, update_gate_col
    输出门:output_gate_row, output_gate_col
    输入门:input_gate_row, input_gate_col
        ↓
更新隐藏状态:
    hidden_slice_row → 更新行隐藏状态
    hidden_slice_col → 更新列隐藏状态
        ↓
生成输出:
    output_slice = torch.cat([hidden_slice_row, hidden_slice_col], dim=-1) → [128, 64]

7. 总结

  • 更新门:控制上一时间步的隐藏状态保留多少。
  • 输入门:控制当前时间步的输入信息引入多少。
  • 输出门:控制隐藏状态的输出强度。
  • 最终输出:将行隐藏状态和列隐藏状态拼接,作为当前时间步的输出。

这种门机制类似于 LSTM,但针对二维时间序列数据进行了扩展,分别更新行和列的隐藏状态,从而捕获时间序列的复杂依赖关系。

相关文章:

  • 高级语言调用C接口(四)结构体(2)-Python
  • 如何在本地修改 Git 项目的远程仓库地址
  • 智算启新篇 安全筑新基 ——中国移动举办智算基础设施及安全分论坛
  • C++ 智能指针底层逻辑揭秘:优化内存管理的核心技术解读
  • Java 中常见的数据结构
  • 3、组件:魔法傀儡的诞生——React 19 组件化开发全解析
  • 【Python爬虫】详细入门指南
  • UNet深度学习实战遥感航拍图像语义分割
  • Java雪花算法
  • RabbitMQ的应用
  • mysql和mongodb
  • React 之 Redux 第三十二节 Redux 常用API及HOOKS,以及Redux Toolkit核心API使用详解
  • 62. 评论日记
  • java 实现文件编码检测的多种方式
  • Podman技术深度解剖:架构、原理与核心特性解析
  • cocos Spine资源及加载
  • JavaScript Map 对象深度解剖
  • HarmonyOS 第2章 Ability的开发,鸿蒙HarmonyOS 应用开发入门
  • 开源FMC 4路千兆网模块
  • Git 基本使用
  • 4月企业新发放贷款利率处于历史低位
  • “女硕士失踪13年生两孩”案进入审查起诉阶段,哥哥:妹妹精神状态好转
  • 国台办:实现祖国完全统一是大势所趋、大义所在、民心所向
  • 一海南救护车在西藏无任务拉警笛开道,墨脱警方:已处罚教育
  • 英国首相斯塔默住所起火,警方紧急调查情况
  • 区域国别学视域下的东亚文化交涉