当前位置: 首页 > news >正文

AF3 shaped_categorical函数解读

AlphaFold3  data_transforms 模块的 shaped_categorical 函数的作用是根据给定的概率分布 probs 生成随机样本。这个函数主要用于在掩码语言模型(Masked Language Model,MLM)中对 MSA 数据进行随机替换。

源代码:

def shaped_categorical(probs, epsilon=1e-10):
    ds = probs.shape
    num_classes = ds[-1]
    distribution = torch.distributions.categorical.Categorical(
        torch.reshape(probs + epsilon, [-1, num_classes])
    )
    counts = distribution.sample()
    return torch.reshape(counts, ds[:-1])

源码解读:

函数签名
def shaped_categorical(probs, epsilon=1e-10):
  • probs:输入的概率分布张量,形状为 [batch_size, sequence_length, num_classes]

  • epsilon:一个小的常数,用于避免概率为零的情况,确保数值稳定性。默认值为 1e-10

1. 获取概率分布的形状和类别数
ds = probs.shape
num_classes = ds[-1]
  • 获取输入概率分布 probs 的形状 ds

  • 获取最后一维的大小 num_classes,即类别数。

2. 创建分类分布对象
distribution = torch.distributions.categorical.Categorical(
    torch.reshape(probs + epsilon, [-1, num_classes])
)
  • 将概率分布 probs 展平为二维张量 [batch_size * sequence_length, num_classes]

  • 在概率分布中加入一个小的常数 epsilon,以避免概率为零的情况。

  • 使用 torch.distributions.categorical.Categorical 创建一个分类分布对象 distribution

3. 采样并恢复原始形状
counts = distribution.sample()
return torch.reshape(counts, ds[:-1])
  • 使用 distribution.sample() 从分类分布中生成随机样本。生成的样本形状为 [batch_size * sequence_length]

  • 将生成的样本恢复为原始形状 [batch_size, sequence_length]

  • ds[:-1] 表示取 ds 中从第一个元素到最后一个元素之前的切片,即去掉最后一个元素。

函数输出
  • 函数返回一个形状为 [batch_size, sequence_length] 的张量,其中每个位置的值是从对应的概率分布中采样得到的类别索引。


总结

  • shaped_categorical 函数的作用:根据给定的概率分布生成随机样本,并确保样本的形状与输入概率分布的形状一致。

  • 核心逻辑

    1. 获取概率分布的形状和类别数。

    2. 将概率分布展平为二维张量,并加入一个小的常数以避免概率为零。

    3. 使用 torch.distributions.categorical.Categorical 创建分类分布对象。

    4. 从分类分布中生成随机样本,并将样本恢复为原始形状。

  • 输出:一个形状为 [batch_size, sequence_length] 的张量,表示从概率分布中采样得到的类别索引。

这个函数在处理掩码语言模型(MLM)任务时非常关键,它用于生成掩码位置的随机替换值。

http://www.dtcms.com/a/62801.html

相关文章:

  • 大数据hadoop课程笔记
  • HTTPS协议原理:在Linux世界里的加密冒险
  • Jupyter Notebook 全平台安装与配置教程(附Python/Anaconda双方案)
  • Spring(3)—— 获取http头部信息
  • 如何创建一个Vue项目
  • 在Visual Studio 2022中实现Qt插件开发
  • 低版本 Linux 系统通过二进制方式升级部署高版本 Docker
  • Win7 火狐浏览器 Mozilla Firefox 115.7.0esr下载地址(及Chrome、Supermium浏览器)
  • Session、Cookie、Token的区别
  • OceanBase社区年度之星专访:张稚京与OB社区的双向奔赴
  • 算法手记1
  • 基于Java的面向对象的多态示例
  • Maven指定JDK
  • function call为大模型装上触手
  • Java中的分布式锁:原理、实现与最佳实践
  • webpack介绍
  • Android Compose Surface 完全指南:从入门到花式操作
  • 四种常见的 API 架构风格(带示例)
  • vue2中,在table单元格上右键,对行、列的增删操作(模拟wps里的表格交互)
  • 无人机全景应用解析与技术演进趋势
  • AI开源竞赛与硬件革命:2025年3月科技热点全景解读——阿里、腾讯领跑开源,英特尔、台积电重塑算力格局
  • 考研数学复习之定积分定义求解数列极限(超详细教程)
  • HTML5教程之标签(7)
  • Java关键字与标识符
  • 基于6自由度搬运机器人完成单关节伺服控制实现的详细步骤及示例代码
  • 基于YOLO11深度学习的遥感视角地面房屋建筑检测分割与分析系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标分割、人工智能
  • 【GNN】第四章:图卷积层GCN
  • Linux 服务器安全配置:密码复杂度与登录超时设置
  • 缓存id路由页面返回,历史路由栈
  • SpringBoot基础Kafka示例