AF3 中 ProteinDataset 类的 _process 方法解读
AlphaFold3 中 protein_dataset 模块的 ProteinDataset 类的 `_process` 方法,主要功能是读取单个 proteinflow 格式的蛋白质数据文件,将其转换为 ProteinMPNN 所需的特征,并最终保存为 pickle 文件(按 entry_type 划分出的每个链组合各保存为一个 `<文件名>_<序号>.pickle`)。
源代码:
def _process(
self,
filename,
rewrite=False,
max_length=None,
min_cdr_length=None,
classes_to_exclude=None,
):
"""Process a proteinflow file and save it as ProteinMPNN features."""
input_file = os.path.join(self.dataset_folder, filename)
no_extension_name = filename.split(".")[0]
data_entry = ProteinEntry.from_pickle(input_file)
if self.load_ligands:
ligands = ProteinEntry.retrieve_ligands_from_pickle(input_file)
if classes_to_exclude is not None:
if data_entry.get_protein_class() in classes_to_exclude:
return []
chains = data_entry.get_chains()
if self.entry_type == "biounit":
chain_sets = [chains]
elif self.entry_type == "chain":
chain_sets = [[x] for x in chains]
elif self.entry_type == "pair":
if len(chains) == 1:
return []
chain_sets = list(combinations(chains, 2))
else:
raise RuntimeError(
"Unknown entry type, please choose from ['biounit', 'chain', 'pair']"
)
output_names = []
if self.cut_edges:
data_entry.cut_missing_edges()
for chains_i, chain_set in enumerate(chain_sets):
output_file = os.path.join(
self.features_folder, no_extension_name + f"_{chains_i}.pickle"
)
pass_set = False
add_name = True
if os.path.exists(output_file) and not rewrite:
pass_set = True
if max_length is not None:
if data_entry.get_length(chain_set) > max_length:
add_name = False
if min_cdr_length is not None and data_entry.has_cdr():
cdr_length = data_entry.get_cdr_length(chain_set)
if not all(
[
length >= min_cdr_length
for length in cdr_length.values()
if length > 0
]
):
add_name = False
else:
if max_length is not None:
if data_entry.get_length(chains=chain_set) > max_length:
pass_set = True
add_name = False
if min_cdr_length is not None and data_entry.has_cdr():
cdr_length = data_entry.get_cdr_length(chain_set)
if not all(
[
length >= min_cdr_length
for length in cdr_length.values()
if length > 0
]
):
pass_set = True
add_name = False
if self.entry_type == "pair":
if not data_entry.is_valid_pair(*chain_set):
pass_set = True
add_name = False
out = {}
if add_name:
cdr_chain_set = set()
if data_entry.has_cdr():
out["cdr"] = torch.tensor(
data_entry.get_cdr(chain_set, encode=True)
)
chain_type_dict = data_entr