当前位置: 首页 > news >正文

AF3 shaped_categorical函数解读

AlphaFold3  data_transforms 模块的 shaped_categorical 函数的作用是根据给定的概率分布 probs 生成随机样本。这个函数主要用于在掩码语言模型(Masked Language Model,MLM)中对 MSA 数据进行随机替换。

源代码:

def shaped_categorical(probs, epsilon=1e-10):
    ds = probs.shape
    num_classes = ds[-1]
    distribution = torch.distributions.categorical.Categorical(
        torch.reshape(probs + epsilon, [-1, num_classes])
    )
    counts = distribution.sample()
    return torch.reshape(counts, ds[:-1])

源码解读:

函数签名
def shaped_categorical(probs, epsilon=1e-10):
  • probs:输入的概率分布张量,形状为 [batch_size, sequence_length, num_classes]

  • epsilon:一个小的常数,用于避免概率为零的情况,确保数值稳定性。默认值为 1e-10

1. 获取概率分布的形状和类别数
ds = probs.shape
num_classes = ds[-1]
  • 获取输入概率分布 probs 的形状 ds

  • 获取最后一维的大小 num_classes,即类别数。

2. 创建分类分布对象
distribution = torch.distributions.categorical.Categorical(
    torch.reshape(probs + epsilon, [-1, num_classes])
)
  • 将概率分布 probs 展平为二维张量 [batch_size * sequence_length, num_classes]

  • 在概率分布中加入一个小的常数 epsilon,以避免概率为零的情况。

  • 使用 torch.distributions.categorical.Categorical 创建一个分类分布对象 distribution

3. 采样并恢复原始形状
counts = distribution.sample()
return torch.reshape(counts, ds[:-1])
  • 使用 distribution.sample() 从分类分布中生成随机样本。生成的样本形状为 [batch_size * sequence_length]

  • 将生成的样本恢复为原始形状 [batch_size, sequence_length]

  • ds[:-1] 表示取 ds 中从第一个元素到最后一个元素之前的切片,即去掉最后一个元素。

函数输出
  • 函数返回一个形状为 [batch_size, sequence_length] 的张量,其中每个位置的值是从对应的概率分布中采样得到的类别索引。


总结

  • shaped_categorical 函数的作用:根据给定的概率分布生成随机样本,并确保样本的形状与输入概率分布的形状一致。

  • 核心逻辑

    1. 获取概率分布的形状和类别数。

    2. 将概率分布展平为二维张量,并加入一个小的常数以避免概率为零。

    3. 使用 torch.distributions.categorical.Categorical 创建分类分布对象。

    4. 从分类分布中生成随机样本,并将样本恢复为原始形状。

  • 输出:一个形状为 [batch_size, sequence_length] 的张量,表示从概率分布中采样得到的类别索引。

这个函数在处理掩码语言模型(MLM)任务时非常关键,它用于生成掩码位置的随机替换值。

相关文章:

  • 大数据hadoop课程笔记
  • HTTPS协议原理:在Linux世界里的加密冒险
  • Jupyter Notebook 全平台安装与配置教程(附Python/Anaconda双方案)
  • Spring(3)—— 获取http头部信息
  • 如何创建一个Vue项目
  • 在Visual Studio 2022中实现Qt插件开发
  • 低版本 Linux 系统通过二进制方式升级部署高版本 Docker
  • Win7 火狐浏览器 Mozilla Firefox 115.7.0esr下载地址(及Chrome、Supermium浏览器)
  • Session、Cookie、Token的区别
  • OceanBase社区年度之星专访:张稚京与OB社区的双向奔赴
  • 算法手记1
  • 基于Java的面向对象的多态示例
  • Maven指定JDK
  • function call为大模型装上触手
  • Java中的分布式锁:原理、实现与最佳实践
  • webpack介绍
  • Android Compose Surface 完全指南:从入门到花式操作
  • 四种常见的 API 架构风格(带示例)
  • vue2中,在table单元格上右键,对行、列的增删操作(模拟wps里的表格交互)
  • 无人机全景应用解析与技术演进趋势
  • 川大全职引进考古学家宫本一夫,他曾任日本九州大学副校长
  • 晒被子最大的好处,其实不是杀螨虫,而是……
  • 科普|治疗腰椎间盘突出症,筋骨平衡理论如何提供新视角?
  • 库里22分赢下抢七大战,火箭10年难破“火勇大战”的魔咒
  • 深入景区、文化街区及消费一线,多地省委书记调研文旅市场
  • 张家口一景区观光魔毯疑失控致游客被甩出,涉事景区改造升级重新开园才3天