当前位置: 首页 > news >正文

AF3 OpenFoldDataset类解读

AlphaFold3 data_modules 模块的 OpenFoldDataset 类是一个自定义的数据集类,继承自 torch.utils.data.Dataset。它的目的是在训练时实现 随机过滤器(stochastic filters),用于从多个不同的数据集(OpenFoldSingleDataset 或 OpenFoldSingleMultimerDataset)中进行样本选择和过滤。该类具有一些 数据增强 和 过滤 的机制,以提高模型的训练质量。

源代码:

class OpenFoldDataset(torch.utils.data.Dataset):
    """
        Implements the stochastic filters applied during AlphaFold's training.
        Because samples are selected from constituent datasets randomly, the
        length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
        and filtered once at initialization.
    """

    def __init__(self,
                 datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
                 probabilities: Sequence[float],
                 epoch_len: int,
                 generator: torch.Generator = None,
                 _roll_at_init: bool = True,
                 ):
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        self.datapoints = None

        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
        if _roll_at_init:
            self.reroll()

    @staticmethod
    def deterministic_train_filter(
        cache_entry: Any,
        max_resolution: float = 9.,
        max_single_aa_prop: float = 0.8,
        *args, **kwargs
    ) -> bool:
        # Hard filters
        resolution = cache_entry.get("resolution", None)
        seqs = [cache_entry["seq"]]

        return all([resolution_filter(resolution=resolution,
                                      max_resolution=max_resolution),
                    aa_count_filter(seqs=seqs,
                                    max_single_aa_prop=max_single_aa_prop)])

    @staticmethod
    def get_stochastic_train_filter_prob(
        cache_entry: Any,
        *args, **kwargs
    ) -> float:
        # Stochastic filters
        probabilities = []

        cluster_size = cache_entry.get("cluster_size", None)
        if cluster_size is not None and cluster_size > 0:
            probabilities.append(1 / cluster_size)

        chain_length = len(cache_entry["seq"])
 
http://www.dtcms.com/a/109314.html

相关文章:

  • 【面试篇】Kafka
  • 记录学习的第二十天
  • 【LeetCode 题解】数据库:626.换座位
  • Java基础:Logback日志框架
  • C# 与 相机连接
  • 接收灵敏度的基本概念与技术解析
  • 【计网】作业三
  • 2025年2月,美国发布了新版移动灯的安规标准:UL153标准如何办理?
  • MySQL:库表操作
  • CATIA装配体全自动存储解决方案开发实战——基于递归算法的产品结构树批量处理技术
  • 一款非常小的软件,操作起来非常丝滑!
  • 语音识别播报人工智能分类垃圾桶(论文+源码)
  • MySQL 基础使用指南-MySQL登录与远程登录
  • MySQL超全笔记
  • 快速掌握MCP——Spring AI MCP包教包会
  • Pyspark学习二:快速入门基本数据结构
  • 4月3号.
  • Python 函数知识梳理与经典编程题解析
  • FFmpeg录制屏幕和音频
  • 单片机学习之定时器
  • 嵌入式海思Hi3861连接华为物联网平台操作方法
  • Zapier MCP:重塑跨应用自动化协作的技术实践
  • 【Linux】Orin NX + Ubuntu22.04配置国内源
  • 如何实现一个优雅的Go协程池
  • ORION:基于VLM引导动作生成的端到端框架——论文精度
  • 源码分析之Leaflet图层控制控件Control.Layers实现原理
  • 量子计算与人工智能的结合:未来科技的双重革命
  • 人工智能混合编程实践:C++ ONNX进行图像超分重建
  • 从零实现Json-Rpc框架】- 项目实现 - 服务端主题实现及整体封装
  • “清凉海岛·创享一夏” 海南启动旅游线路产品创意设计大赛