当前位置: 首页 > news >正文

AF3 ProteinDataset类的_get_masked_sequence方法解读

AlphaFold3 protein_dataset模块 ProteinDataset 类 _get_masked_sequence 方法属于作用是为需要预测的残基生成掩码。该掩码以二进制张量形式呈现,其中 1 代表需要预测的部分,0 代表其他部分。此方法会依据多个参数来选定要掩码的残基,这些参数包含 mask_whole_chainsmask_fraclower_limitupper_limitmask_sequential 以及 force_binding_sites_frac 等。

源代码:

    def _get_masked_sequence(
        self,
        data,
    ):
        """Get the mask for the residues that need to be predicted.

        Depending on the parameters the residues are selected as follows:
        - if `mask_whole_chains` is `True`, the whole chain is masked
        - if `mask_frac` is given, the number of residues to mask is `mask_frac` times the length of the chain,
        - otherwise, the number of residues to mask is sampled uniformly from the range [`lower_limit`, `upper_limit`].

        If `mask_sequential` is `True`, the residues are masked based on the order in the sequence, otherwise a
        spherical mask is applied based on the coordinates.

        If `force_binding_sites_frac` > 0 and `mask_whole_chains` is `False`, in the fraction of cases where a chain
        from a polymer is sampled, the center of the masked region will be forced to be in a binding site.

        Parameters
        ----------
        data : dict
            an entry generated by `ProteinDataset`

        Returns
        -------
        chain_M : torch.Tensor
            a `(B, L)` shaped binary tensor where 1 denotes the part that needs to be predicted and
            0 is everything else

        """
        if "cdr" in data and "cdr_id" in data:
            chain_M = torch.zeros_like(data["cdr"])
            if self.mask_all_cdrs:
                chain_M = data["cdr"] != CDR_REVERSE["-"]
            else:
                chain_M = data["cdr"] == data["cdr_id"]
        else:
            chain_M = torch.zeros_like(data["S"])
            chain_index = data["chain_id"]
            chain_bool = data["chain_encoding_all"] == chain_index

            if self.mask_whole_chains:
                chain_M[chain_bool] = 1
            else:
                chains = torch.unique(data["chain_encoding_all"])
                chain_start = torch.where(chain_bool)[0][0]
                chain = data["X"][chain_bool]
                res_i = None
                interface = []
                non_masked_interface = []
                if len(chains) > 1 and self.force_binding_sites_frac > 0:
                    if random.uniform(0, 1) <= self.force_binding_sites_frac:
                        X_copy = data["X"]

                        i_indices = (chain_bool == 0).nonzero().flatten()  # global
                 

相关文章:

  • Linux Kernel 1
  • gazebo 启动卡死的解决方法汇总
  • transformers的 pipeline是什么:将模型加载、数据预处理、推理等步骤进行了封装
  • Linux下Docker安装超详细教程(以CentOS为例)
  • transformer 规范化层
  • Linux 进程基础(一):冯诺依曼结构
  • Java设计模式实战:策略模式在SimUDuck问题中的应用
  • 使用Fortran读取HDF5数据
  • 若依前后端分离版运行教程、打包教程、部署教程
  • Linux-内核驱动
  • Window 10使用WSL2搭建Linux版Android Studio应用开发环境
  • Redis集群模式学习
  • Kubernetes nodeName Manual Scheduling practice (K8S节点名称绑定以及手工调度)
  • 【高性能缓存Redis_中间件】一、快速上手redis缓存中间件
  • 大型语言模型中的工具调用(Function Calling)技术详解
  • 鸿蒙开发05评论案例分析
  • 基于 Streamlit 的 PDF 编辑器
  • 1558 找素数
  • vue模拟扑克效果
  • AdamW 是 Adam 优化算法的改进版本; warmup_steps:学习率热身的步数
  • 最优网络做网站怎么样/域名查询官网
  • 嘉兴云推广网站/如何制作网页最简单的方法
  • 通州网站制作/吉林seo推广
  • 大连网站搜索优/长沙seo优化公司
  • 深圳集团网站建设服务/本地推广平台有哪些
  • 郑州做网站茂睿科技/网络销售的方法和技巧