当前位置: 首页 > news >正文

AF3 Rotation 类解读

Rotation 类(rigid_utils 模块)是 AlphaFold3 中用于 3D旋转 的核心组件,支持两种旋转表示: 1️⃣ 旋转矩阵 (3x3)
2️⃣ 四元数 (quaternion, 4元向量)

👉 设计目标

  • 允许灵活选择 旋转矩阵 或 四元数

  • 封装了常用的 旋转操作(组合、逆旋转、应用到点上等)

  • 像 torch.Tensor 一样,支持索引、拼接、广播等操作

源代码:

class Rotation:
    """
        A 3D rotation. Depending on how the object is initialized, the
        rotation is represented by either a rotation matrix or a
        quaternion, though both formats are made available by helper functions.
        To simplify gradient computation, the underlying format of the
        rotation cannot be changed in-place. Like Rigid, the class is designed
        to mimic the behavior of a torch Tensor, almost as if each Rotation
        object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
            Args:
                rot_mats:
                    A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                    quats
                quats:
                    A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                    normalize_quats is not True, must be a unit quaternion
                normalize_quats:
                    If quats is specified, whether to normalize quats
        """
        if((rot_mats is None and quats is None) or 
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or 
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision
        if(quats is not None):
            quats = quats.to(dtype=torch.float32)
        if(rot_mats is not None):
            rot_mats = rot_mats.to(dtype=torch.float32)

        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
            Returns an identity Rotation.

            Args:
                shape:
                    The "shape" of the resulting Rotation object. See documentation
                    for the shape property
                dtype:
                    The torch dtype for the rotation
                device:
                    The torch device for the new rotation
                requires_grad:
                    Whether the underlying tensors in the new rotation object
                    should require gradient computation
                fmt:
                    One of "quat" or "rot_mat". Determines the underlying format
                    of the new object's rotation 
            Returns:
                A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: f{fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
            Allows torch-style indexing over the virtual shape of the rotation
            object. See documentation for the shape property.

            Args:
                index:
                    A torch index. E.g. (1, 3, 2), or (slice(None,))
            Returns:
                The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
            Pointwise left multiplication of the rotation with a tensor. Can be
            used to e.g. mask the Ro

相关文章:

  • stc8g1k08a+cd4017红绿灯
  • 嵌入式学习(31)-Lora模块A39C-T400A30D1a
  • 数据结构5(初):续写排序
  • HarmonyOS NEXT(九) :图形渲染体系
  • 嵌入式八股文学习——STL相关内容学习
  • 测试专项4:AI算法测试在测试行业中,该如何定位自己自述
  • 地理编码/经纬度解析/经纬度地址转换接口如何用JAVA对接
  • ui_auto_study(持续更新)
  • 当今前沿科技:改变世界的最新技术趋势
  • 【Spring】深入理解 Spring 事务管理
  • VScode
  • Java 中的多线程:核心概念与应用场景
  • 机器学习——KNN数据均一化
  • Qt文件管理系统
  • Spring AI相关的面试题
  • 算法如何测试,如果数据量很大怎么办?
  • 逆波兰表达式
  • [Lc17_多源 BFS_最短路] 矩阵 | 飞地的数量 | 地图中的最高点 | 地图分析
  • 串口接收不到数据,串口RX配置(f407),f103和f407的区别
  • Linux第二章第三章练习
  • 占地57亩的“潮汕豪宅”面临强制拆除:曾被实施没收,8年间举行5次听证会
  • 4月份全国企业销售收入同比增长4.3%
  • 媒体:“西北大学副校长范代娣成陕西首富”系乌龙,但她的人生如同开挂
  • 内塔尼亚胡:以军将在未来几天“全力进入”加沙
  • 彭丽媛同巴西总统夫人罗桑热拉参观中国国家大剧院
  • 某博主遭勒索后自杀系自导自演,成都警方立案调查