AF3 rot_matmul and rot_vec_mul: a walkthrough
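The rot_matmul and rot_vec_mul functions in AlphaFold3's rigid_utils module compute two products by hand: the product of two rotation matrices, A×B, and the matrix-vector product R×t. Writing the arithmetic out as elementwise multiplies and adds avoids an AMP (Automatic Mixed Precision) problem. Inside a torch.autocast region, matrix-multiplication ops such as torch.matmul are downcast to float16/bfloat16. Elementwise multiplication and addition instead keep the float32 dtype of their inputs, so the rotations stay in full precision.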
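The following is a minimal sketch of the downcasting in question. It assumes a PyTorch build with CPU autocast support, which uses bfloat16 on CPU; on GPU, `torch.autocast("cuda")` behaves analogously and defaults to float16. The broadcast-and-sum expression is an illustrative stand-in for the hand-written products, not code taken from AlphaFold3.

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 3, 3)  # float32 batch of 3x3 matrices
b = torch.randn(4, 3, 3)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # matmul is on autocast's lower-precision op list, so it runs in bfloat16
    c_matmul = a @ b
    # elementwise mul and sum are not downcast, so the float32 dtype survives
    c_by_hand = (a[..., :, :, None] * b[..., None, :, :]).sum(dim=-2)

print(c_matmul.dtype)   # torch.bfloat16
print(c_by_hand.dtype)  # torch.float32
```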
Source code:
```python
import torch


def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
    """
    Performs matrix multiplication of two rotation matrix tensors. Written
    out by hand to avoid AMP downcasting.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        The product ab
    """
    def row_mul(i):
        # Row i of the product: entry j is sum_k a[i, k] * b[k, j]
        return torch.stack(
            [
                a[..., i, 0] * b[..., 0, 0]
                + a[..., i, 1] * b[..., 1, 0]
                + a[..., i, 2] * b[..., 2, 0],
                a[..., i, 0] * b[..., 0, 1]
                + a[..., i, 1] * b[..., 1, 1]
                + a[..., i, 2] * b[..., 2, 1],
                a[..., i, 0] * b[..., 0, 2]
                + a[..., i, 1] * b[..., 1, 2]
                + a[..., i, 2] * b[..., 2, 2],
            ],
            dim=-1,
        )

    # Stack the three [*, 3] rows along dim -2 to form the [*, 3, 3] product
    return torch.stack(
        [
            row_mul(0),
            row_mul(1),
            row_mul(2),
        ],
        dim=-2
    )
```
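Each entry that row_mul builds is C_ij = Σ_k A_ik B_kj, written out for k = 0, 1, 2. The sketch below re-expresses the same elementwise arithmetic with loops and runs two checks: the result agrees with torch.matmul, and composing two rotations about the z axis adds their angles. `rot_matmul_loops` and `rot_z` are hypothetical helpers written for illustration only.

```python
import math
import torch


def rot_matmul_loops(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as row_mul: only elementwise * and +, no matmul kernel
    rows = []
    for i in range(3):
        row = [sum(a[..., i, k] * b[..., k, j] for k in range(3)) for j in range(3)]
        rows.append(torch.stack(row, dim=-1))  # row i: [*, 3]
    return torch.stack(rows, dim=-2)           # rows stacked: [*, 3, 3]


def rot_z(theta: float) -> torch.Tensor:
    # Rotation by theta radians about the z axis
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


torch.manual_seed(0)
a, b = torch.randn(5, 3, 3), torch.randn(5, 3, 3)
print(torch.allclose(rot_matmul_loops(a, b), a @ b, atol=1e-5))  # True
# Rz(0.3) @ Rz(0.5) should equal Rz(0.8)
print(torch.allclose(rot_matmul_loops(rot_z(0.3), rot_z(0.5)), rot_z(0.8), atol=1e-6))  # True
```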
```python
def rot_vec_mul(
    r: torch.Tensor,
    t: torch.Tensor
) -> torch.Tensor:
    """
    Applies a rotation to a vector. Written out by hand to avoid AMP
    downcasting.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates
    """
    # Split the coordinates into x, y, z components, each of shape [*]
    x, y, z = torch.unbind(t, dim=-1)
    # Component i of the result is row i of r dotted with (x, y, z)
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )
```
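rot_vec_mul follows the same pattern. torch.unbind splits t into x, y and z, and output component i is the dot product of row i of r with t: (Rt)_i = R_i0 x + R_i1 y + R_i2 z. The sketch below checks two things. First, a broadcast formulation gives the same result as `r @ t[..., None]`. Second, a 90° rotation about z sends the x axis to the y axis. `rot_vec_mul_bcast` and `rot_z` are hypothetical helpers for illustration.

```python
import math
import torch


def rot_vec_mul_bcast(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # r: [*, 3, 3], t: [*, 3]; t[..., None, :] lines t up with every row of r,
    # and summing over the last dim gives sum_j r[i, j] * t[j]
    return (r * t[..., None, :]).sum(dim=-1)


def rot_z(theta: float) -> torch.Tensor:
    # Rotation by theta radians about the z axis
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


torch.manual_seed(0)
r, t = torch.randn(6, 3, 3), torch.randn(6, 3)
print(torch.allclose(rot_vec_mul_bcast(r, t), (r @ t[..., None])[..., 0], atol=1e-5))  # True

x_axis = torch.tensor([1.0, 0.0, 0.0])
y_axis = torch.tensor([0.0, 1.0, 0.0])
print(torch.allclose(rot_vec_mul_bcast(rot_z(math.pi / 2), x_axis), y_axis, atol=1e-6))  # True
```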
Code walkthrough:
```python
def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
```
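In the signature, `[*, 3, 3]` means any number of leading batch dimensions followed by a 3×3 matrix. Every term in row_mul multiplies `[*]`-shaped slices such as `a[..., i, 0]`, so the leading dimensions follow ordinary broadcasting rules. The `torch.stack(..., dim=-1)` inside row_mul then builds one `[*, 3]` row, and the outer `torch.stack(..., dim=-2)` stacks three rows into `[*, 3, 3]`. The shapes below (7 frames, one shared matrix) are arbitrary choices for illustration, and random values stand in for rotations because only the shapes matter here.

```python
import torch

a = torch.randn(7, 3, 3)  # e.g. a stack of 7 rotation frames
b = torch.randn(1, 3, 3)  # a single matrix, broadcast against all 7

# Each term in row_mul multiplies slices of shape [*] (one scalar per matrix)
term = a[..., 0, 0] * b[..., 0, 0]
print(term.shape)  # torch.Size([7])

# Stacking three such entries along dim=-1 gives one row: [*, 3]
row = torch.stack([term, term, term], dim=-1)
print(row.shape)  # torch.Size([7, 3])

# Stacking three rows along dim=-2 gives the full product shape: [*, 3, 3]
mat = torch.stack([row, row, row], dim=-2)
print(mat.shape)  # torch.Size([7, 3, 3])
```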
✅