嘉定企业网站开发站长之家seo信息
AlphaFold3 protein_dataset 模块中 ProteinDataset 类的 _process 方法,主要功能是处理单个 proteinflow 文件,将其转换为 ProteinMPNN 所需的特征,并最终保存为 pickle 文件。
源代码:
def _process(self,filename,rewrite=False,max_length=None,min_cdr_length=None,classes_to_exclude=None,):"""Process a proteinflow file and save it as ProteinMPNN features."""input_file = os.path.join(self.dataset_folder, filename)no_extension_name = filename.split(".")[0]data_entry = ProteinEntry.from_pickle(input_file)if self.load_ligands:ligands = ProteinEntry.retrieve_ligands_from_pickle(input_file)if classes_to_exclude is not None:if data_entry.get_protein_class() in classes_to_exclude:return []chains = data_entry.get_chains()if self.entry_type == "biounit":chain_sets = [chains]elif self.entry_type == "chain":chain_sets = [[x] for x in chains]elif self.entry_type == "pair":if len(chains) == 1:return []chain_sets = list(combinations(chains, 2))else:raise RuntimeError("Unknown entry type, please choose from ['biounit', 'chain', 'pair']")output_names = []if self.cut_edges:data_entry.cut_missing_edges()for chains_i, chain_set in enumerate(chain_sets):output_file = os.path.join(self.features_folder, no_extension_name + f"_{chains_i}.pickle")pass_set = Falseadd_name = Trueif os.path.exists(output_file) and not rewrite:pass_set = Trueif max_length is not None:if data_entry.get_length(chain_set) > max_length:add_name = Falseif min_cdr_length is not None and data_entry.has_cdr():cdr_length = data_entry.get_cdr_length(chain_set)if not all([length >= min_cdr_lengthfor length in cdr_length.values()if length > 0]):add_name = Falseelse:if max_length is not None:if data_entry.get_length(chains=chain_set) > max_length:pass_set = Trueadd_name = Falseif min_cdr_length is not None and data_entry.has_cdr():cdr_length = data_entry.get_cdr_length(chain_set)if not all([length >= min_cdr_lengthfor length in cdr_length.values()if length > 0]):pass_set = Trueadd_name = Falseif self.entry_type == "pair":if not data_entry.is_valid_pair(*chain_set):pass_set = Trueadd_name = Falseout = {}if add_name:cdr_chain_set = set()if data_entry.has_cdr():out["cdr"] = 
torch.tensor(data_entry.get_cdr(chain_set, encode=True))chain_type_dict = data_entr