AF3 correct_msa_restypes函数解读
AlphaFold3 data_transforms 模块的
correct_msa_restypes 函数 该函数的作用是 将 MSA(多序列比对)的氨基酸索引转换为 AlphaFold3 预期的索引顺序,确保 MSA 数据的氨基酸类型与 rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
兼容,同时 转换 MSA 相关的 profile 数据。
源代码:
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as rc."""
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
[new_order_list] * protein["msa"].shape[1],
device=protein["msa"].device,
).transpose(0, 1)
protein["msa"] = torch.gather(new_order, 0, protein["msa"])
perm_matrix = np.zeros((22, 22), dtype=np.float32)
perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
for k in protein:
if "profile" in k:
num_dim = protein[k].shape.as_list()[-1]
assert num_dim in [
20,
21,
22,
], "num_dim for %s out of expected range: %s" % (k, num_dim)
protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
return protein
源码解读:
该函数的输入 protein
是一个包含 MSA 相关信息的字典,主要处理 "msa"
和 "profile"
相关的键。
1️⃣ 重新映射 MSA 的氨基酸索引
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
是一个