AF3 OpenFoldSingleMultimerDataset类解读
AlphaFold3 data_modules 模块中的 OpenFoldSingleMultimerDataset 类继承自 torch.utils.data.Dataset,用于处理多聚体(Multimer,即由多条链组成的蛋白质复合物)数据,提供数据加载、特征提取、模板匹配等功能。该类支持训练(train)、验证(eval)、推理(predict)三种模式,并在 __getitem__ 方法中读取 .cif 结构文件、MSA 数据和模板特征,最终返回标准化的蛋白质特征张量。
源代码:
class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
mmcif_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
"""
This class check each individual PDB ID and return its chain(s) features/ground truth
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
mmcif_data_cache_path:
Path to cache of all mmcifs files generated by
scripts/generate_mmcif_cache.py It should be a json file which records
what PDB ID contains which chain(s)
kalign_binary_path:
Path to kalign binary.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
super(OpenFoldSingleMultimerDataset, self).__init__()
self.data_dir = data_dir
self.mmcif_data_cache_path = mmcif_data_cache_path
self.mmcif_data_cache = None
if self.mmcif_data_cache_path is not None:
with open(self.mmcif_data_cache_path, "r") as infile:
self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw
self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if mode not in valid_modes:
raise ValueError(f'mode must be one of {valid_modes}')
if template_release_dates_cache_path is None:
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if self.mmcif_data_cache_path is not None:
self._mmcifs = list(self.mmcif_data_cache.keys())
elif self.alignment_index is not None:
self._mmcifs = [i.split("_")[0] for i in list(alignment_index.keys())]
elif self.alignment_dir is not None:
self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
else:
raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
if filter_path is not None:
with open(filter_path, "r") as f:
mmcifs_to_include = set([l.strip() for l in f.readlines()])
self._mmcifs = [
m for m in self._mmcifs if m in mmcifs_to_include
]
self._mmcif_id_to_idx_dict = {
mmcif: i for i, mmcif in enumerate(self._mmcifs)
}
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir,