当前位置: 首页 > news >正文

AF3 Rotation类的map_tensor_fn 方法解读

AlphaFold3 rigid_utils 模块Rotation类的 map_tensor_fn方法主要作用是对旋转矩阵或四元数上的最后一维应用一个函数 (fn) ,并返回一个新的 Rotation 对象。

源代码:

    def map_tensor_fn(self, 
        fn: Callable[torch.Tensor, torch.Tensor]
    ) -> Rotation:
        """
            Apply a Tensor -> Tensor function to underlying rotation tensors,
            mapping over the rotation dimension(s). Can be used e.g. to sum out
            a one-hot batch dimension.

            Args:
                fn:
                    A Tensor -> Tensor function to be mapped over the Rotation 
            Returns:
                The transformed Rotation object
        """ 
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            rot_mats = torch.stack(
                list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
            )
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = torch.stack(
                list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
            )
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

代码解读:

方法签名
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:
  • fn:接收一个 Tensor,返回一个 Tensor,典型用途是对旋转的某个维度做变换,比如求和、加权平均等。

  • 返回值:一个新的 Rotation 对象,里面装着变换后的旋转矩阵 (rot_mats) 或四元数 (quats)。

处理旋转矩阵 (_rot_mats)

如果 self._rot_mats 存在,就走这条分支:

if self._rot_mats is not None:
    # 把 (batch_size, ..., 3, 3) reshape 成 (batch_size, ..., 9)
    rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))

✅ 解释
view() 是为了把 3x3 的旋转矩阵摊平成 9 维向量,方便对最后一维应用函数。

rot_mats = torch.stack(
    list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
)

✅ 解释

  1. torch.unbind():沿最后一维解开成 9 个独立的张量。

  2. map(fn, ...):对每个解开的张量应用 fn

  3. torch.stack():把变换后的 9 个张量重新堆叠回去。

注: torch.unbind 维度 -1 ,torch.stack 维度 +1, 并且都处理相同的维度(-1)。

rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)

✅ 解释
把 9 维向量重新 reshaped 成 (3, 3) 矩阵,并用它创建一个新的 Rotation 对象。

处理四元数 (_quats)

如果矩阵不存在,走四元数分支:

elif self._quats is not None:
    quats = torch.stack(
        list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
    )
    return Rotation(rot_mats=None, quats=quats, normalize_quats=False)

✅ 解释

  • 逻辑和矩阵类似,先 unbind() 分解四元数的最后一维,对每个部分应用 fn(),再 stack() 堆叠回来。

  • 创建新 Rotation 对象时加了 normalize_quats=False,说明这一步不需要再归一化。

 防错处理

如果两个旋转表示都没有,抛出异常:

else:
    raise ValueError("Both rotations are None")

总结

map_tensor_fn() 是一种 高阶函数,它能灵活地对旋转矩阵或四元数的最后一维执行各种操作(比如求和、加权、归一化、剪裁等)。

核心逻辑:

  • 矩阵路径 → reshape(9维) → 分解 → 应用函数 → 堆叠 → 恢复3x3

  • 四元数路径 → 分解 → 应用函数 → 堆叠

相关文章:

  • Oracle 23ai Vector Search 系列之1 架构基础
  • RT-Thread CI编译产物artifacts自动上传功能介绍
  • python socket模块学习记录
  • KPMG 与 SAP Joule:引领 AI 驱动咨询的新时代
  • 什么情况下spring的事务会失效
  • 私域电商的进化逻辑与技术赋能:基于开源AI大模型与S2B2C商城的创新融合研究
  • C#设计模式快速回顾
  • C语言- 工厂模式详解与实践
  • 常见中间件漏洞攻略-Apache篇
  • datetime“陷阱”与救赎:扒“时间差值”证道
  • Pytorch实现之对称卷积神经网络结构实现超分辨率
  • Pytorch深度学习教程_9_nn模块构建神经网络
  • 数据结构——哈夫曼编码、哈夫曼树
  • SAP-ABAP:SAP BW模块架构与实战应用详解
  • 使用Python将视频转化为gif
  • AF3 Rotation 类解读
  • stc8g1k08a+cd4017红绿灯
  • 嵌入式学习(31)-Lora模块A39C-T400A30D1a
  • 数据结构5(初):续写排序
  • HarmonyOS NEXT(九) :图形渲染体系
  • 浙能集团原董事长童亚辉被查,还是杭州市书法家协会主席
  • “11+2”复式票,宝山购彩者领走大乐透1170万头奖
  • 睡觉总做梦是睡眠质量差?梦到这些事,才要小心
  • 小米SU7 Ultra风波升级:数百名车主要求退车,车主喊话雷军“保持真诚”
  • 英国首相斯塔默住所起火,警方紧急调查情况
  • 李公明 | 一周画记:印巴交火会否升级为第四次印巴战争?