当前位置: 首页 > news >正文

AF3 rot_matmul 和 rot_vec_mul函数解读

AlphaFold3  rigid_utils 模块的 rot_matmul 和 rot_vec_mul 函数实现了手动计算两个旋转矩阵的乘法 A×B 以及矩阵-向量乘法 R×t, 避免了直接用矩阵乘法的AMP(Automatic Mixed Precision)问题。

源代码:

def rot_matmul(
    a: torch.Tensor, 
    b: torch.Tensor
) -> torch.Tensor:
    """
        Performs matrix multiplication of two rotation matrix tensors. Written
        out by hand to avoid AMP downcasting.

        Args:
            a: [*, 3, 3] left multiplicand
            b: [*, 3, 3] right multiplicand
        Returns:
            The product ab
    """
    def row_mul(i):
        return torch.stack(
            [
                a[..., i, 0] * b[..., 0, 0]
                + a[..., i, 1] * b[..., 1, 0]
                + a[..., i, 2] * b[..., 2, 0],
                a[..., i, 0] * b[..., 0, 1]
                + a[..., i, 1] * b[..., 1, 1]
                + a[..., i, 2] * b[..., 2, 1],
                a[..., i, 0] * b[..., 0, 2]
                + a[..., i, 1] * b[..., 1, 2]
                + a[..., i, 2] * b[..., 2, 2],
            ],
            dim=-1,
        )

    return torch.stack(
        [
            row_mul(0), 
            row_mul(1), 
            row_mul(2),
        ], 
        dim=-2
    )


def rot_vec_mul(
    r: torch.Tensor, 
    t: torch.Tensor
) -> torch.Tensor:
    """
        Applies a rotation to a vector. Written out by hand to avoid transfer
        to avoid AMP downcasting.

        Args:
            r: [*, 3, 3] rotation matrices
            t: [*, 3] coordinate tensors
        Returns:
            [*, 3] rotated coordinates
    """
    x, y, z = torch.unbind(t, dim=-1)
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )

代码解读:

def rot_matmul(
    a: torch.Tensor, 
    b: torch.Tensor
) -> torch.Tensor:

相关文章:

  • 【算法学习之路】13.BFS
  • 大语言模型进化论:从文本理解到多模态认知的革命之路
  • 高斯数据库-WDR Snapshot生成性能报告
  • 【商城实战(56)】商城数据生命线:恢复流程与演练全解析
  • datawhale组队学习--大语言模型—task4:Transformer架构及详细配置
  • 7. 二叉树****
  • Proteus 使用入门指南
  • Powershell WSL .wslconfig 实现与宿主机的网络互通
  • 0322-数据库、前后端
  • SSE详解面试常考问题详解
  • 基于 Vue 3 的PDF和Excel导出
  • Ubuntu22.04通过DKMS包安装Intel WiFi系列适配器(网卡驱动)
  • JavaScript 中 “new Map()”的使用
  • AI语音聊天机器人APP(使用webrtc、语音识别、TTL、langchain、大语语模型、uniapp)
  • 用坦克比喻理解类的封装性
  • 二叉树的层序遍历||(107)
  • 用 pytorch 从零开始创建大语言模型(六):对分类进行微调
  • C++中,构造函数和析构函数
  • 初识HTTP
  • 一维前缀和与二维前缀和的详细用法和介绍
  • 阿坝州委书记徐芝文已任四川省政府党组成员
  • 经济日报整版聚焦:上海构建法治化营商环境,交出高分答卷
  • 共情场域与可持续发展——关于博物馆、美术馆运营的新思考
  • 习近平出席中拉论坛第四届部长级会议开幕式并发表主旨讲话
  • 石家庄推动城市能级与民生福祉并进
  • 中方发布会:中美经贸高层会谈取得了实质性进展,达成了重要共识