AF3 rot_to_quat函数解读
AlphaFold3 rigid_utils 模块的
rot_to_quat
函数的功能是把旋转矩阵转换为四元数,利用K矩阵提取最大特征值对应的特征向量,即为四元数。
源代码:
def rot_to_quat(
rot: torch.Tensor,
):
if(rot.shape[-2:] != (3, 3)):
raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
k = [
[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
]
k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
_, vectors = torch.linalg.eigh(k)
return vectors[..., -1]
代码解读:
函数入口
def rot_to_quat(rot: torch.Tensor):
if(rot.shape[-2:] != (3, 3)):
raise ValueError("Input rotation is incorrectly shaped")
检查输入维度,确保输入是 3×3 旋转矩阵。
提取矩阵元素
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
拆解矩阵元素,提取出 Rij 的各个元素。
构造矩阵 K
k = [
[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
]
构造对称矩阵 K,对应前面推导出的矩阵公式。
标准化矩阵
k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
将矩阵 K 按公式标准化,确保数值稳定。
求最大特征值和特征向量
_, vectors = torch.linalg.eigh(k)
利用特征值分解:
-
torch.linalg.eigh()
返回 最小到最大 的特征值和特征向量。 -
我们只要 最后一个特征向量(对应最大特征值)。
提取最终四元数
return vectors[..., -1]
返回最后一个特征向量,即我们最终求出的四元数 (a,b,c,d)。
关键总结
1️⃣ 矩阵 K 是核心,它是从旋转矩阵推导回四元数的桥梁,最大特征值的特征向量就是四元数。
2️⃣ 代码巧妙地构建了矩阵 K,并且用了 torch.linalg.eigh()
直接提取最大特征向量,避免了复杂的符号判断和条件分支。
3️⃣ 优雅高效 🎯!比传统的行列式法或者逐项推导更稳定、更易实现。
理论基础:
1. 从旋转矩阵到四元数的目标
2. 四元数与旋转矩阵的关系
3. 逆推的核心理论