当前位置: 首页 > news >正文

AF3 _correct_post_merged_feats函数解读

AlphaFold3 msa_pairing 模块的 _correct_post_merged_feats 函数用于对合并后的特征进行修正,确保它们符合预期的格式和要求。这包括可能的对特征值进行调整或进一步的格式化,确保合并后的 FeatureDict 适合于后续模型的输入。

主要作用是:

  1. 在多链蛋白质 MSA(多序列比对)合并后,重新计算/调整某些特征
    • seq_length(序列长度)
    • num_alignments(MSA 比对的序列数)
  2. 为 MSA 生成合适的掩码(mask),用于模型训练:
    • cluster_bias_mask:控制 MSA 的 query 序列位置。
    • bert_mask:用于 BERT-style MSA 预训练掩码。

源代码:

def _correct_post_merged_feats(
        np_example: Mapping[str, np.ndarray],
        np_chains_list: Sequence[Mapping[str, np.ndarray]],
        pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
    """Adds features that need to be computed/recomputed post merging."""

    np_example['seq_length'] = np.asarray(
        np_example['aatype'].shape[0],
        dtype=np.int32
    )
    np_example['num_alignments'] = np.asarray(
        np_example['msa'].shape[0],
        dtype=np.int32
    )

    if not pair_msa_sequences:
        # Generate a bias that is 1 for the first row of every block in the
        # block diagonal MSA - i.e. make sure the cluster stack always includes
        # the query sequences for each chain (since the first row is the query
        # sequence).
        cluster_bias_masks = []
        for chain in np_chains_list:
            mask = np.zeros(chain['msa'].shape[0])
            mask[0] = 1
            cluster_bias_masks.append(mask)

        np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.float32)
            for x in np_chains_list
        ]

        np_example['bert_mask'] = block_diag(
            *msa_masks, pad_value=0
        )
    else:
        np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
        np_example['cluster_bias_mask'][0] = 1

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.float32) for
            x in np_chains_list
        ]
        msa_masks_all_seq = [
            np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
            x in np_chains_list
        ]

    

相关文章:

  • 解决VSCode鼠标光标指针消失
  • 分布式锁实现方案对比与最佳实践
  • 【计网】数据链路层
  • Glide图片加载优化全攻略:从缓存到性能调优
  • python官方文档阅读整理(一)
  • 2024最新版Java面试题及答案,【来自于各大厂】
  • 【ORACLE】char类型和sql优化器发生的“错误”反应
  • 【工具推荐】在线提取PDF、文档、图片、论文中的公式
  • 数字万用表的使用教程
  • 学习 Wireshark 分析 Android Netlog
  • 什么是SElinux?
  • MongoDB Chunks核心概念与机制
  • 【前端】HTML 备忘清单(超级详细!)
  • 深入探索Python机器学习算法:模型调优
  • vue3,Element Plus中抽屉el-drawer的样式设置
  • 爬虫逆向实战小记——解决captcha滑动验证码
  • linux 安装Mysql无法远程访问问题的排查
  • JavaWeb5、Maven
  • 【安装】SQL Server 2005 安装及安装包
  • python-leetcode 47.路径总和III
  • 有没有专做推广小说的网站/seo黑帽教学网
  • 西安黄页网/百度官方优化软件
  • 有哪些做调查问卷赚钱的网站/百度搜索大全
  • 做网站编程时容易遇到的问题/seo搜索引擎推广什么意思
  • 网站开发的毕业设计题目/开户推广竞价开户
  • 网站建设公司兴田德润可信赖/做好网络推广