AF3: A Walkthrough of the _get_masked_sequence Method of the ProteinDataset Class
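In the AlphaFold3 protein_dataset module, the _get_masked_sequence method of the ProteinDataset class generates a mask for the residues that need to be predicted. The mask is a binary tensor in which 1 marks the positions to predict and 0 marks everything else. The method chooses which residues to mask based on several parameters, including mask_whole_chains, mask_frac, lower_limit, upper_limit, mask_sequential, and force_binding_sites_frac.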
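As a rough illustration of how these parameters interact, the sketch below picks a masked region for a single chain. It is not the library's implementation: the function name sketch_chain_mask, the rng argument, the default limits, and the clamping of the sampled count to the chain length are all assumptions made for this example. It covers only whole-chain and sequential masking; the spherical mask and the binding-site forcing are left out. Plain Python lists stand in for torch tensors so the example runs without extra dependencies.

```python
import random


def sketch_chain_mask(chain_ids, target_chain, mask_whole_chains=False,
                      mask_frac=None, lower_limit=2, upper_limit=4, rng=None):
    """Return a 0/1 list marking the residues of `target_chain` to predict.

    Hypothetical helper: a simplified, sequential-only version of the
    selection rules described in the method's docstring.
    """
    rng = rng or random.Random()
    # Global positions of the residues that belong to the selected chain.
    chain_positions = [i for i, c in enumerate(chain_ids) if c == target_chain]
    mask = [0] * len(chain_ids)
    if not chain_positions:
        return mask

    if mask_whole_chains:
        # Rule 1: mask every residue of the chain.
        for i in chain_positions:
            mask[i] = 1
        return mask

    chain_len = len(chain_positions)
    if mask_frac is not None:
        # Rule 2: a fixed fraction of the chain length.
        n_mask = int(mask_frac * chain_len)
    else:
        # Rule 3: sample the count uniformly from [lower_limit, upper_limit].
        n_mask = rng.randint(lower_limit, upper_limit)
    # Assumption for this sketch: keep the count between 1 and the chain length.
    n_mask = max(1, min(n_mask, chain_len))

    # Sequential mask: one contiguous window at a random offset in the chain.
    start = rng.randint(0, chain_len - n_mask)
    for i in chain_positions[start:start + n_mask]:
        mask[i] = 1
    return mask


# Whole-chain masking is deterministic, so the result can be stated directly.
print(sketch_chain_mask([0, 0, 1, 1, 1], 1, mask_whole_chains=True))  # [0, 0, 1, 1, 1]
```

Passing a seeded random.Random instance as rng makes the sampled window reproducible, which is convenient when comparing the effect of mask_frac against the lower_limit and upper_limit range.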
Source code:
def _get_masked_sequence(
    self,
    data,
):
    """Get the mask for the residues that need to be predicted.

    Depending on the parameters the residues are selected as follows:
    - if `mask_whole_chains` is `True`, the whole chain is masked
    - if `mask_frac` is given, the number of residues to mask is `mask_frac` times the length of the chain,
    - otherwise, the number of residues to mask is sampled uniformly from the range [`lower_limit`, `upper_limit`].

    If `mask_sequential` is `True`, the residues are masked based on the order in the sequence, otherwise a
    spherical mask is applied based on the coordinates.

    If `force_binding_sites_frac` > 0 and `mask_whole_chains` is `False`, in the fraction of cases where a chain
    from a polymer is sampled, the center of the masked region will be forced to be in a binding site.

    Parameters
    ----------
    data : dict
        an entry generated by `ProteinDataset`

    Returns
    -------
    chain_M : torch.Tensor
        a `(B, L)` shaped binary tensor where 1 denotes the part that needs to be predicted and
        0 is everything else

    """
    if "cdr" in data and "cdr_id" in data:
        chain_M = torch.zeros_like(data["cdr"])
        if self.mask_all_cdrs:
            chain_M = data["cdr"] != CDR_REVERSE["-"]
        else:
            chain_M = data["cdr"] == data["cdr_id"]
    else:
        chain_M = torch.zeros_like(data["S"])
        chain_index = data["chain_id"]
        chain_bool = data["chain_encoding_all"] == chain_index
        if self.mask_whole_chains:
            chain_M[chain_bool] = 1
        else:
            chains = torch.unique(data["chain_encoding_all"])
            chain_start = torch.where(chain_bool)[0][0]
            chain = data["X"][chain_bool]
            res_i = None
            interface = []
            non_masked_interface = []
            if len(chains) > 1 and self.force_binding_sites_frac > 0:
                if random.uniform(0, 1) <= self.force_binding_sites_frac:
                    X_copy = data["X"]
                    i_indices = (chain_bool == 0).nonzero().flatten()  # global