当前位置: 首页 > news >正文

haiku实现三角乘法模块

三角乘法(TriangleMultiplication)是三角注意力(TriangleAttention)的一种更对称、计算开销更低的替代模块。

import haiku
import haiku as hk
import jax
import jax.numpy as jnp

def _layer_norm(axis=-1, name='layer_norm'):
  """Builds a `common_modules.LayerNorm` with a fixed set of options.

  The layer is created with a learnable scale (initialised to 1) and a
  learnable offset (initialised to 0), `eps=1e-5` and
  `use_fast_variance=True`.

  Args:
    axis: Forwarded to the layer both as `axis` and as `param_axis`.
    name: Haiku module name given to the layer.

  Returns:
    The constructed `common_modules.LayerNorm` module; it is not applied to
    any input here.
  """
  ones_init = hk.initializers.Constant(1.)
  zeros_init = hk.initializers.Constant(0.)
  settings = dict(
      axis=axis,
      param_axis=axis,
      create_scale=True,
      create_offset=True,
      eps=1e-5,
      use_fast_variance=True,
      scale_init=ones_init,
      offset_init=zeros_init,
      name=name,
  )
  return common_modules.LayerNorm(**settings)


class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").

  The einsum string in `config.equation` decides how the two gated
  projections are combined, and therefore which of the two variants runs.

  Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
  Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
  """

  def __init__(self, config, global_config, name='triangle_multiplication'):
    """Stores the configuration objects.

    Arguments:
      config: Module config. This class reads `fuse_projection_weights`,
        `num_intermediate_channel` and `equation` from it.
      global_config: Global config; only handed to `utils.final_init`.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, left_act, left_mask, is_training=True):
    """Builds TriangleMultiplication module.

    Dispatches to the fused or unfused implementation depending on
    `config.fuse_projection_weights`.

    Arguments:
      left_act: Pair activations, shape [N_res, N_res, c_z]
      left_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode. Unused.

    Returns:
      Outputs, same shape/type as left_act.
    """
    del is_training

    if not self.config.fuse_projection_weights:
      return self._triangle_multiplication(left_act, left_mask)
    return self._fused_triangle_multiplication(left_act, left_mask)

  # hk.transparent keeps Haiku from opening a separate name scope for this
  # helper method, so the submodules created inside it are scoped by the
  # module itself rather than by the module plus the method name.
  @hk.transparent
  def _triangle_multiplication(self, left_act, left_mask):
    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3.

    Arguments:
      left_act: Pair activations, see `__call__`.
      left_mask: Pair mask, see `__call__`.

    Returns:
      Gated output activations whose last dimension equals that of the
      layer-normalised input.
    """
    cfg = self.config
    global_cfg = self.global_config

    def sigmoid_gate(inputs, width, name):
      # A plain closure rather than a method, so no additional Haiku name
      # scope is introduced around the Linear layer.
      gate_layer = common_modules.Linear(
          width,
          bias_init=1.,
          initializer=utils.final_init(global_cfg),
          name=name)
      return jax.nn.sigmoid(gate_layer(inputs))

    # Trailing axis lets the mask broadcast over the channel dimension.
    pair_mask = left_mask[..., None]

    input_norm = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='layer_norm_input')
    normed_input = input_norm(left_act)

    left_linear = common_modules.Linear(
        cfg.num_intermediate_channel, name='left_projection')
    left_proj_act = pair_mask * left_linear(normed_input)

    right_linear = common_modules.Linear(
        cfg.num_intermediate_channel, name='right_projection')
    right_proj_act = pair_mask * right_linear(normed_input)

    left_gate_values = sigmoid_gate(
        normed_input, cfg.num_intermediate_channel, 'left_gate')
    right_gate_values = sigmoid_gate(
        normed_input, cfg.num_intermediate_channel, 'right_gate')

    left_proj_act = left_proj_act * left_gate_values
    right_proj_act = right_proj_act * right_gate_values

    # "Outgoing" edges equation: 'ikc,jkc->ijc'
    # "Incoming" edges equation: 'kjc,kic->ijc'
    # Note on the Suppl. Alg. 11 & 12 notation:
    # For the "outgoing" edges, a = left_proj_act and b = right_proj_act
    # For the "incoming" edges, it's swapped:
    #   b = left_proj_act and a = right_proj_act
    combined = jnp.einsum(cfg.equation, left_proj_act, right_proj_act)

    center_norm = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')
    combined = center_norm(combined)

    # Project back to the channel count of the normalised input.
    output_channel = int(normed_input.shape[-1])

    output_linear = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(global_cfg),
        name='output_projection')
    output = output_linear(combined)

    # The output gate is computed from the normalised input, not from the
    # combined activations.
    output_gate = sigmoid_gate(normed_input, output_channel, 'gating_linear')
    return output * output_gate

  # Transparent for the same reason as `_triangle_multiplication`: no extra
  # name scope is added for the submodules created in this helper.
  @hk.transparent
  def _fused_triangle_multiplication(self, left_act, left_mask):
    """TriangleMultiplication with fused projection weights.

    A single double-width projection and a single double-width gate replace
    the separate left/right layers; the result is split in half along the
    last axis before the einsum.

    Arguments:
      left_act: Pair activations, see `__call__`.
      left_mask: Pair mask, see `__call__`.

    Returns:
      Gated output activations whose last dimension equals that of the
      layer-normalised input.
    """
    cfg = self.config
    global_cfg = self.global_config

    # Trailing axis lets the mask broadcast over the channel dimension.
    pair_mask = left_mask[..., None]

    normed_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)

    # Both left and right projections are fused into projection.
    fused_projection = common_modules.Linear(
        2 * cfg.num_intermediate_channel, name='projection')
    fused_act = pair_mask * fused_projection(normed_act)

    # Both left + right gate are fused into a single gate layer.
    fused_gate = common_modules.Linear(
        2 * cfg.num_intermediate_channel,
        name='gate',
        bias_init=1.,
        initializer=utils.final_init(global_cfg))
    fused_act = fused_act * jax.nn.sigmoid(fused_gate(normed_act))

    # First half of the channels is the left operand, second half the right.
    split_at = cfg.num_intermediate_channel
    left_proj_act = fused_act[:, :, :split_at]
    right_proj_act = fused_act[:, :, split_at:]
    combined = jnp.einsum(cfg.equation, left_proj_act, right_proj_act)

    combined = _layer_norm(axis=-1, name='center_norm')(combined)

    # Project back to the channel count of the normalised input.
    output_channel = int(normed_act.shape[-1])

    output_linear = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(global_cfg),
        name='output_projection')
    output = output_linear(combined)

    output_gate = common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(global_cfg),
        name='gating_linear')
    return output * jax.nn.sigmoid(output_gate(normed_act))


 

相关文章:

  • 说一下mysql的锁
  • Configure Virtual Serial Port Driver串口模拟器VSPD
  • 【手把手带你玩转MyBatis】基础篇:掌握事务管理,确保数据操作的原子性与一致性
  • 【JVM调优系列】如何导出堆内存文件
  • 微信小程序支付之V2支付
  • QT上位机开发(进度条操作)
  • 2024.1.14
  • 【驱动】TI AM437x(内核调试-06):网卡(PHY和MAC)、七层OSI
  • C++笔记
  • springcloud gateway动态路由
  • Erlang/OTP中的日志与事件处理(一)
  • vue2使用electron以及打包配置
  • 【小白专用】C# 连接 MySQL 数据库
  • K8S 日志方案
  • webpack的性能优化(二)——减少打包体积
  • Baumer工业相机堡盟工业相机如何使用OpenCV实现相机图像的显示(C#)
  • Pandas实战100例 | 案例 13: 数据分类 - 使用 `cut` 对数值进行分箱
  • 软件测试|SQLAlchemy环境安装与基础使用
  • Ftrans飞驰云联荣获“CSA 2023安全创新奖”
  • Spark详解
  • 财政部党组召开2025年巡视工作会议暨第一轮巡视动员部署会
  • 中国海警舰艇编队5月14日在我钓鱼岛领海巡航
  • 沈阳卫健委通报“健康证”办理乱象:涉事医院已被立案查处
  • 国际博物馆日中国主会场确定,北京将展“看·见殷商”等展览
  • 十三届全国政协经济委员会副主任张效廉被决定逮捕
  • 中美经贸中方牵头人、国务院副总理何立峰出席新闻发布会表示:中美达成重要共识,会谈取得实质性进展