AF3 _correct_post_merged_feats函数解读
AlphaFold3 msa_pairing 模块的 _correct_post_merged_feats
函数用于对合并后的特征进行修正,确保它们符合预期的格式和要求。这包括可能的对特征值进行调整或进一步的格式化,确保合并后的 FeatureDict
适合于后续模型的输入。
主要作用是:
- 在多链蛋白质 MSA(多序列比对)合并后,重新计算/调整某些特征:
seq_length
(序列长度)num_alignments
(MSA 比对的序列数)
- 为 MSA 生成合适的掩码(mask),用于模型训练:
cluster_bias_mask
:控制 MSA 的 query 序列位置。bert_mask
:用于 BERT-style MSA 预训练掩码。
源代码:
def _correct_post_merged_feats(
np_example: Mapping[str, np.ndarray],
np_chains_list: Sequence[Mapping[str, np.ndarray]],
pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
"""Adds features that need to be computed/recomputed post merging."""
np_example['seq_length'] = np.asarray(
np_example['aatype'].shape[0],
dtype=np.int32
)
np_example['num_alignments'] = np.asarray(
np_example['msa'].shape[0],
dtype=np.int32
)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list
]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0
)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list
]
msa_masks_all_seq = [
np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list
]