AF3: a walkthrough of the process_tensors_from_config function
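The process_tensors_from_config function in the AlphaFold3 input_pipeline module processes the input tensors according to the configuration (common_cfg, mode_cfg). It applies a series of transformations to them and returns the processed data. It combines two transformation strategies, which suit different model training or inference modes:
- Non-ensembled transforms are applied once to the whole feature dict.
- Ensembled transforms are applied once per copy, with num_recycling + 1 copies. The copies are then stacked along a new trailing dimension.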
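The two strategies can be illustrated with a minimal pure-Python sketch of the same control flow. It makes several simplifying assumptions. The transforms are hypothetical stand-ins: a doubling step for the non-ensembled stage and an index-dependent shift for the ensembled stage. Features are plain numbers instead of tensors. `process_sketch` and `run` are illustrative names, not part of the pipeline.

```python
from typing import Any, Callable, Dict, List

Features = Dict[str, Any]
Transform = Callable[[Features], Features]


def run(d: Features, fns: List[Transform]) -> Features:
    # Apply each transform in order, like compose() in the source.
    for f in fns:
        d = f(d)
    return d


def process_sketch(features: Features, num_recycling: int) -> Dict[str, List[Any]]:
    # Stage 1: hypothetical non-ensembled transform, applied exactly once.
    def double(d: Features) -> Features:
        return {k: v * 2 for k, v in d.items()}

    features = run(features, [double])

    # Stage 2: hypothetical ensembled transform, applied once per copy.
    def shift(d: Features) -> Features:
        return {**d, "x": d["x"] + d["ensemble_index"]}

    def per_copy(i: int) -> Features:
        d = dict(features)        # fresh dict per copy, like data.copy()
        d["ensemble_index"] = i   # each copy records its own index
        return run(d, [shift])

    copies = [per_copy(i) for i in range(num_recycling + 1)]
    # Gather each feature across copies (torch.stack(..., dim=-1) in the source).
    return {k: [c[k] for c in copies] for k in copies[0]}


out = process_sketch({"x": 1.0}, num_recycling=3)
print(out)  # {'x': [2.0, 3.0, 4.0, 5.0], 'ensemble_index': [0, 1, 2, 3]}
```

The non-ensembled step runs only once, so its cost does not grow with the number of recycling iterations. Every feature ends up with one entry per copy, which gives four entries when num_recycling = 3.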
Source code:
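import random

import torch

# data_transforms, nonensembled_transform_fns and ensembled_transform_fns
# come from the surrounding data pipeline and are not shown in this excerpt.


def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""
    # One random seed, drawn once and shared by every ensemble copy.
    ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()  # shallow copy of the feature dict for this copy
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    # Computed here but not used later in this function.
    no_templates = True
    if "template_aatype" in tensors:
        no_templates = tensors["template_aatype"].shape[0] == 0

    # Non-ensembled transforms: applied exactly once.
    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )
    tensors = compose(nonensembled)(tensors)

    # Recycling iterations: per-example value if present, else the config default.
    if "no_recycling_iters" in tensors:
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    # Ensembled transforms: applied once per copy (num_recycling + 1 copies),
    # then each feature is stacked along a new trailing dimension.
    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )
    return tensors


@data_transforms.curry1
def compose(x, fs):
    # Apply the functions in fs to x, in order.
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, x):
    # Call fun on every element of x, then stack each feature along dim=-1.
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict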
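The decorator and the stacking step are the least obvious parts of the code above. The sketch below re-implements them with NumPy so the resulting shapes are easy to see. It rests on a few assumptions:
- `curry1` here is an assumed equivalent of `data_transforms.curry1`. It binds every argument except the first and returns a function of that first argument.
- `np.stack(..., axis=-1)` stands in for `torch.stack(..., dim=-1)`.
- The `aatype` transform and `per_copy` are hypothetical.

```python
import numpy as np


def curry1(f):
    """Assumed behaviour of data_transforms.curry1: bind all args except the first."""
    def bind(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return bind


@curry1
def compose(x, fs):
    # Thread x through each function in fs, left to right.
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, xs):
    # Run fun once per element, then stack each feature along a new last axis.
    ensembles = [fun(elem) for elem in xs]
    return {
        feat: np.stack([d[feat] for d in ensembles], axis=-1)
        for feat in ensembles[0]
    }


# compose(fs) returns a one-argument function, so it can be built once and reused.
pipeline = compose([lambda d: {**d, "aatype": d["aatype"] + 1}])  # hypothetical transform
features = {"aatype": np.zeros(5, dtype=np.int64)}  # 5 residues


def per_copy(i):
    d = pipeline(dict(features))
    d["ensemble_index"] = np.array(i)
    return d


out = map_fn(per_copy, range(4))  # num_recycling = 3 -> 4 copies
print(out["aatype"].shape)        # (5, 4)
print(out["ensemble_index"])      # [0 1 2 3]
```

Stacking happens on the last axis. As a result, a per-residue feature of shape (5,) becomes (5, 4), and the i-th copy can be recovered with `out["aatype"][..., i]` without touching the leading dimensions.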
Code walkthrough:
1. Function signature
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and trans