AF3 identity_rot_mats函数解读
AlphaFold3 rigid_utils 模块的 identity_rot_mats
函数的目标是创建指定批次维度的旋转不变矩阵(即 3x3 的恒等矩阵),它适合多批次数据处理,比如 AlphaFold3 里的蛋白质多体建模。
旋转不变矩阵 就是:
它表示 不旋转,是所有旋转矩阵的起点。
源代码:
@lru_cache(maxsize=None)
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
rots = rots.contiguous()
return