AF3: Understanding the _add_batch_properties Method of the OpenFoldDataLoader Class
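This post looks at the _add_batch_properties method of the OpenFoldDataLoader class in the AlphaFold3 data_modules module. The method attaches recycling-related properties to each data batch. It samples one value per property key from a probability distribution, with one draw per row of self.prop_probs_tensor. An example of such a property is the number of recycling iterations. It then writes that value into the batch. Because only one value is drawn per key, every example in the batch shares the same value, such as the same recycling count.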
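The following minimal sketch reproduces the sample-and-broadcast idea outside the class. It requires PyTorch. The function name add_batch_properties_sketch, the probability table, and the demo_key property are all hypothetical and exist only for illustration. The final expand call is also an assumption: the source code below is cut off before the line that follows view, so treat that step as a plausible completion rather than the original.

```python
import torch


def add_batch_properties_sketch(batch, prop_keys, prop_probs, generator=None):
    """Hypothetical standalone version of the sampling/broadcast step.

    prop_probs: float tensor [num_keys, num_values]; row i is the
    distribution for prop_keys[i], columns are candidate values 0..num_values-1.
    """
    # One draw per row -> shape [num_keys, 1]
    samples = torch.multinomial(
        prop_probs, num_samples=1, replacement=True, generator=generator
    )
    aatype = batch["aatype"]
    batch_dims = aatype.shape[:-2]      # leading batch dims
    recycling_dim = aatype.shape[-1]    # trailing recycling dim
    for i, key in enumerate(prop_keys):
        sample = int(samples[i][0])
        t = torch.tensor(sample, device=aatype.device, requires_grad=False)
        orig_shape = t.shape            # torch.Size([]) for a 0-d tensor
        # Add singleton axes: one per batch dim in front, one at the end
        t = t.view((1,) * len(batch_dims) + orig_shape + (1,))
        # ASSUMPTION: broadcast to batch_dims + (recycling_dim,); this step
        # is not visible in the truncated source
        batch[key] = t.expand(
            tuple(batch_dims) + tuple(orig_shape) + (recycling_dim,)
        )
    return batch


# Usage: batch dims (2,), 5 residues, recycling dim of size 4
gen = torch.Generator().manual_seed(0)  # seeded, like self.generator
probs = torch.tensor([
    [0.25, 0.25, 0.25, 0.25],  # "no_recycling_iters": uniform over 0..3
    [0.0, 0.0, 0.0, 1.0],      # "demo_key" (hypothetical): always 3
])
batch = {"aatype": torch.zeros(2, 5, 4, dtype=torch.long)}
batch = add_batch_properties_sketch(
    batch, ["no_recycling_iters", "demo_key"], probs, gen
)
print(batch["no_recycling_iters"].shape)  # torch.Size([2, 4])
```

A column with zero probability is never drawn, so padding a row with zeros keeps it valid without changing its distribution. Because expand returns a view, the broadcast value costs no extra memory per example.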
Source code:
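def _add_batch_properties(self, batch):
    # TODO: gt_features might change
    # Take the ground-truth features out of the batch (None if absent)
    gt_features = batch.pop('gt_features', None)
    # Draw one index per row of prop_probs_tensor, i.e. one value
    # per property key listed in self.prop_keys
    samples = torch.multinomial(
        self.prop_probs_tensor,
        num_samples=1,  # 1 per row
        replacement=True,
        generator=self.generator
    )

    aatype = batch["aatype"]
    # Leading batch dims: everything except the last two axes
    # (residue and recycling)
    batch_dims = aatype.shape[:-2]
    # Size of the trailing recycling dimension
    recycling_dim = aatype.shape[-1]
    # Initial recycling count: the full recycling dimension
    no_recycling = recycling_dim
    for i, key in enumerate(self.prop_keys):
        # Sampled index for the i-th property, as a Python int
        sample = int(samples[i][0])
        # Wrap it in a 0-d tensor on the same device as the batch
        sample_tensor = torch.tensor(
            sample,
            device=aatype.device,
            requires_grad=False
        )
        orig_shape = sample_tensor.shape
        # Insert singleton axes: one per batch dim in front and one
        # trailing axis for the recycling dimension
        sample_tensor = sample_tensor.view(
            (1,) * len(batch_dims) + sample_tensor.shape + (1,)
        )
        sample_tensor = sample_tenso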