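The process_tensors_from_config function in the AlphaFold3 input_pipeline module processes the input tensors according to the configuration (common_cfg, mode_cfg), applies a series of transformations, and returns the processed data. It uses two transformation strategies. Non-ensembled transforms are applied once to the whole feature dict. Ensembled transforms are applied num_recycling + 1 times, once per recycling iteration, with a single random seed shared across iterations, and the per-iteration results are stacked along a new trailing dimension. Both transform lists are built from common_cfg and mode_cfg, so the same function serves different training and inference modes.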
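The two-stage flow can be sketched in plain Python as follows. This is a toy model, not the real pipeline: add_mask, make_ensembled_fns, tag_iteration and toy_process are hypothetical names, the features are Python lists instead of tensors, and each feature is gathered along a leading list axis rather than the trailing axis the original uses.

```python
import random


def compose(fns, x):
    # Apply each transform in order, feeding each output into the next one.
    for f in fns:
        x = f(x)
    return x


def add_mask(d):
    # Hypothetical non-ensembled transform: runs once per example.
    return {**d, "mask": [1] * len(d["aatype"])}


def make_ensembled_fns(seed):
    # Hypothetical stand-in for ensembled_transform_fns(common_cfg, mode_cfg, seed).
    def tag_iteration(d):
        i = d["ensemble_index"]
        return {**d, "seed": seed, "sample": [x + i for x in d["aatype"]]}
    return [tag_iteration]


def toy_process(tensors, num_recycling, seed=None):
    # One seed per call, shared by every iteration (mirrors ensemble_seed).
    if seed is None:
        seed = random.randint(0, 2**31 - 1)
    # Stage 1: non-ensembled transforms, applied once.
    tensors = compose([add_mask], tensors)
    # Stage 2: ensembled transforms, applied num_recycling + 1 times.
    copies = []
    for i in range(num_recycling + 1):
        d = dict(tensors)
        d["ensemble_index"] = i
        copies.append(compose(make_ensembled_fns(seed), d))
    # Gather each feature across iterations into a new leading list axis
    # (the original uses torch.stack(..., dim=-1), i.e. a trailing axis).
    return {k: [c[k] for c in copies] for k in copies[0]}


out = toy_process({"aatype": [0, 1, 2]}, num_recycling=2, seed=7)
print(out["sample"])  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
```

Every iteration sees the same seed but a different ensemble_index, which is what lets the real ensembled transforms produce a different sample per recycling iteration while staying reproducible within one call.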
Source code:
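```python
import random

import torch

# data_transforms, nonensembled_transform_fns and ensembled_transform_fns
# come from the surrounding data-pipeline package and are not shown here.


def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""
    # One random seed per call, shared by every ensemble/recycling iteration.
    ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    # Computed here but not used later in this function.
    no_templates = True
    if "template_aatype" in tensors:
        no_templates = tensors["template_aatype"].shape[0] == 0

    # Non-ensembled transforms run once on the whole feature dict.
    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )
    tensors = compose(nonensembled)(tensors)

    # Recycling count: from the features if present, otherwise from the config.
    if "no_recycling_iters" in tensors:
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    # Ensembled transforms run num_recycling + 1 times; map_fn stacks the results.
    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors


@data_transforms.curry1
def compose(x, fs):
    # curry1 binds fs, so compose(fs) returns a function of x.
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, x):
    # Apply fun to every element, then stack each feature along a new last dim.
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
```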
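compose is decorated with data_transforms.curry1, which is why the code calls it as compose(fns)(tensors). The sketch below assumes curry1 binds every argument except the first and returns a function of that first argument; this curry1 is a hypothetical re-implementation, not the library's code. It also swaps torch.stack for a hypothetical pure-Python stack_last helper on 1-D lists, so the trailing-axis stacking done by map_fn is visible without PyTorch.

```python
def curry1(f):
    """Bind every argument except the first; return a function of that first argument."""
    def bind(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return bind


@curry1
def compose(x, fs):
    # Apply each transform in order, feeding each output into the next one.
    for f in fs:
        x = f(x)
    return x


def stack_last(rows):
    # Stack R equal-length 1-D lists along a new trailing axis: result is N x R.
    return [list(col) for col in zip(*rows)]


def map_fn(fun, xs):
    # Run fun once per element, then stack each feature across the results.
    results = [fun(x) for x in xs]
    return {k: stack_last([r[k] for r in results]) for k in results[0]}


double_then_inc = compose([
    lambda d: {"a": [v * 2 for v in d["a"]]},
    lambda d: {"a": [v + 1 for v in d["a"]]},
])
print(double_then_inc({"a": [1, 2]}))                  # {'a': [3, 5]}
print(map_fn(lambda i: {"a": [i, 10 * i]}, range(3)))  # {'a': [[0, 1, 2], [0, 10, 20]]}
```

With R iterations and a feature of length N, this map_fn returns an N x R nested list, the same [..., R] layout that torch.stack(..., dim=-1) produces for tensors.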
Code walkthrough:
1. Function signature
def process_tensors_from_config(tensors, common_cfg, mode_cfg):"""Based on the config, apply filters and trans