当前位置: 首页 > news >正文

torch.distributions.categorical.Categorical 介绍

torch.distributions.categorical.Categorical 是 PyTorch 提供的离散概率分布(Categorical Distribution)类,用于从类别型概率分布(Categorical Distribution)中采样随机变量。

1. 语法

torch.distributions.categorical.Categorical(probs=None, logits=None)

2. 参数

参数作用
probs概率分布,形状为 [batch_size, num_classes],其中每行的值应为非负数,且每行的总和为 1.0
logits也可以用 logits 代替 probs,即未归一化的分数(softmax 之前的值),PyTorch 会自动计算 softmax 归一化

⚠ 注意probs 和 logits 只能二选一,否则会报错。

3. 基本用法

(1)用 probs 采样
import torch

# 定义类别概率(3 个类别)
probs = torch.tensor([0.1, 0.3, 0.6])  # 类别 0, 1, 2 的概率分别是 10%, 30%, 60%

# 创建 Categorical 分布
dist = torch.distributions.categorical.Categorical(probs)

# 采样一个类别
sample = dist.sample()
print(sample)  # 输出可能是 0, 1, 或 2,概率分别为 10%, 30%, 60%
(2)用 logits 采样
logits = torch.tensor([1.0, 2.0, 3.0])  # 未归一化的 logits
dist = torch.distributions.categorical.Categorical(logits=logits)

sample = dist.sample()
print(sample)  # 输出 0, 1, 2 的概率由 softmax(logits) 决定

内部计算方式

probs = torch.nn.functional.softmax(logits, dim=-1)

所以,logits [1.0, 2.0, 3.0] 会被转换为:

probs = torch.tensor([0.0900, 0.2447, 0.6652])  # softmax 归一化后的概率

4. 批量采样

如果 probs 是一个二维 Tensor,则可以对多个分布进行批量采样:

probs = torch.tensor([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]])  # 3 组分布
dist = torch.distributions.categorical.Categorical(probs)

samples = dist.sample()
print(samples)  # 每个样本是 0 或 1(按不同的行概率分布)

解释

  • probs[0] = [0.2, 0.8],第 1 个分布中 1 的概率是 80%
  • probs[1] = [0.5, 0.5],第 2 个分布是均匀分布
  • probs[2] = [0.9, 0.1],第 3 个分布中 0 的概率是 90%

5. 计算 log_prob(计算样本的对数概率)

可以计算某个类别出现的 对数概率

probs = torch.tensor([0.1, 0.3, 0.6])
dist = torch.distributions.categorical.Categorical(probs)

log_prob = dist.log_prob(torch.tensor(2))  # 计算类别 2 的对数概率
print(log_prob)  # 输出: -0.5108 (即 log(0.6))

log_prob 的作用:

  • log_prob(x) 计算 类别 x 出现的 log 概率,即 log(P(x))
  • 常用于计算损失函数(如交叉熵)

总结

torch.distributions.categorical.Categorical 是 PyTorch 中用于处理离散分类分布的工具。它支持从分布中采样、计算对数概率和熵,并且可以处理多维输入。在自然语言处理、掩码语言模型和强化学习等任务中,分类分布是一个非常重要的工具。

相关文章:

  • 智慧停车小程序:实时车位查询、导航与费用结算一体化
  • Datawhale AI + 办公 笔记2
  • linux自启动服务
  • 使用 Tesseract 进行 OCR 识别的详细指南
  • Linux开发工具----vim
  • Room数据库的使用
  • STM32Cubemx-H7-7-OLED屏幕
  • 【Python】【数据分析】Python 数据分析与可视化:全面指南
  • 【Python 2D绘图】Matplotlib绘图(统计图表)
  • 【冯诺依曼:到底有什么重大贡献 关键字摘抄】
  • ngx_conf_param
  • JAVA面试_进阶部分_java中四种引用类型(对象的强、软、弱和虚引用)
  • 开发中常见状态码以及状态码用途
  • Mysql8.x常用命令
  • XXE 目录
  • 从零开发Chrome广告拦截插件:开发、打包到发布全攻略
  • 企业网设计
  • 【数据库】10分钟学会MySQL的增删改查:数据库、表、表记录操作指南
  • 数字电路逻辑代数 | 运算 / 定律 / 公式 / 规则 / 例解
  • MySQL创建数据库和表,插入四大名著中的人物
  • 司法部:建立行政执法监督企业联系点,推行行政执法监督员制度
  • 巴基斯坦军方:印度袭击已致巴方31人死亡
  • 成立6天的公司拍得江西第三大水库20年承包经营权,当地回应
  • 马上评|演出服“穿过就退货”的闹剧不该一再重演
  • 默茨在德国联邦议院第一轮投票中未能当选总理
  • 甘肃省政府原党组成员、副省长杨子兴被提起公诉