Implementing the Triangle Multiplication Module in Haiku
Triangle multiplication (TriangleMultiplication) is designed as a more symmetric and cheaper alternative to triangle attention (TriangleAttention).
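Concretely, the "outgoing" update recombines, for every pair (i, j), the edge features at (i, k) and (j, k) over all intermediate residues k, which amounts to a single einsum contraction instead of an attention pass. A minimal, self-contained sketch of that core contraction (the sizes are illustrative, not AlphaFold's real channel widths):

import jax.numpy as jnp

n_res, c = 4, 8                     # illustrative sizes
a = jnp.ones((n_res, n_res, c))     # "left" projections, a[i, k, :]
b = jnp.ones((n_res, n_res, c))     # "right" projections, b[j, k, :]

# "Outgoing" edges (Suppl. Alg. 11): contract over the shared residue k.
out = jnp.einsum('ikc,jkc->ijc', a, b)
print(out.shape)                    # (4, 4, 8)

The full Haiku module follows.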
import jax
import haiku as hk
import jax.numpy as jnp

from alphafold.model import common_modules
from alphafold.model import utils
def _layer_norm(axis=-1, name='layer_norm'):
  return common_modules.LayerNorm(
      axis=axis,
      create_scale=True,
      create_offset=True,
      eps=1e-5,
      use_fast_variance=True,
      scale_init=hk.initializers.Constant(1.),
      offset_init=hk.initializers.Constant(0.),
      param_axis=axis,
      name=name)
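As a quick sanity check, the helper can be exercised through hk.transform like any other Haiku code. A minimal sketch, assuming alphafold's common_modules is importable:

def _norm_fn(x):
  return _layer_norm(axis=-1, name='demo_norm')(x)

norm = hk.transform(_norm_fn)
x = jnp.ones((4, 4, 8))
params = norm.init(jax.random.PRNGKey(0), x)
y = norm.apply(params, None, x)  # same shape as x, normalized over the last axis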
class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").

  Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
  Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
  """

  def __init__(self, config, global_config, name='triangle_multiplication'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, left_act, left_mask, is_training=True):
    """Builds TriangleMultiplication module.

    Arguments:
      left_act: Pair activations, shape [N_res, N_res, c_z].
      left_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      Outputs, same shape/type as left_act.
    """
    del is_training

    if self.config.fuse_projection_weights:
      return self._fused_triangle_multiplication(left_act, left_mask)
    else:
      return self._triangle_multiplication(left_act, left_mask)
  # @hk.transparent is a Haiku decorator that marks the method below as
  # "transparent": the method does not introduce its own name scope, so the
  # parameters created inside it are named as if they were created directly
  # by the enclosing module.
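  # For example, 'left_projection' below is registered as
  # 'triangle_multiplication/left_projection' rather than being nested under
  # an extra '~_triangle_multiplication' method scope.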
  @hk.transparent
  def _triangle_multiplication(self, left_act, left_mask):
    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
    c = self.config
    gc = self.global_config

    mask = left_mask[..., None]

    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='layer_norm_input')(left_act)
    input_act = act

    left_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='left_projection')
    left_proj_act = mask * left_projection(act)

    right_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='right_projection')
    right_proj_act = mask * right_projection(act)

    left_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='left_gate')(act))

    right_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='right_gate')(act))

    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

    # "Outgoing" edges equation: 'ikc,jkc->ijc'
    # "Incoming" edges equation: 'kjc,kic->ijc'
    # Note on the Suppl. Alg. 11 & 12 notation:
    # For the "outgoing" edges, a = left_proj_act and b = right_proj_act
    # For the "incoming" edges, it's swapped:
    #   b = left_proj_act and a = right_proj_act
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(act)

    output_channel = int(input_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = jax.nn.sigmoid(common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(input_act))
    act *= gate_values

    return act
  @hk.transparent
  def _fused_triangle_multiplication(self, left_act, left_mask):
    """TriangleMultiplication with fused projection weights."""
    mask = left_mask[..., None]
    c = self.config
    gc = self.global_config

    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)

    # Both left and right projections are fused into projection.
    projection = common_modules.Linear(
        2 * c.num_intermediate_channel, name='projection')
    proj_act = mask * projection(left_act)

    # Both left and right gates are fused into gate_values.
    gate_values = common_modules.Linear(
        2 * c.num_intermediate_channel,
        name='gate',
        bias_init=1.,
        initializer=utils.final_init(gc))(left_act)

    proj_act *= jax.nn.sigmoid(gate_values)

    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = _layer_norm(axis=-1, name='center_norm')(act)

    output_channel = int(left_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(left_act)
    act *= jax.nn.sigmoid(gate_values)

    return act
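Finally, a minimal end-to-end usage sketch. The config fields below mirror what the module actually reads (num_intermediate_channel, equation, fuse_projection_weights, plus zero_init, which utils.final_init consults); the concrete values, the ConfigDict layout, and the module name are illustrative rather than AlphaFold's published configuration:

import ml_collections

# Illustrative config; the field names match what the module reads above.
config = ml_collections.ConfigDict({
    'num_intermediate_channel': 64,
    'equation': 'ikc,jkc->ijc',        # "outgoing"; 'kjc,kic->ijc' for "incoming"
    'fuse_projection_weights': False,  # True selects the fused code path
})
global_config = ml_collections.ConfigDict({'zero_init': True})

def _fn(pair_act, pair_mask):
  return TriangleMultiplication(
      config, global_config, name='triangle_multiplication_outgoing')(
          pair_act, pair_mask)

fwd = hk.transform(_fn)
n_res, c_z = 8, 32
pair_act = jnp.ones((n_res, n_res, c_z))
pair_mask = jnp.ones((n_res, n_res))
params = fwd.init(jax.random.PRNGKey(0), pair_act, pair_mask)
out = fwd.apply(params, None, pair_act, pair_mask)
print(out.shape)  # (8, 8, 32)

Because both implementation methods are decorated with @hk.transparent, the entries in params sit directly under the module name, e.g. 'triangle_multiplication_outgoing/left_projection'.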