torch.distributions.categorical.Categorical 介绍
torch.distributions.categorical.Categorical
是 PyTorch 提供的离散概率分布(Categorical Distribution)类,用于从类别型概率分布(Categorical Distribution)中采样随机变量。
1. 语法
torch.distributions.categorical.Categorical(probs=None, logits=None)
2. 参数
参数 | 作用 |
---|---|
probs | 概率分布,形状为 [batch_size, num_classes] ,其中每行的值应为非负数,且每行的总和为 1.0 |
logits | 也可以用 logits 代替 probs ,即未归一化的分数(softmax 之前的值),PyTorch 会自动计算 softmax 归一化 |
⚠ 注意:
probs
和logits
只能二选一,否则会报错。
3. 基本用法
(1)用 probs
采样
import torch
# 定义类别概率(3 个类别)
probs = torch.tensor([0.1, 0.3, 0.6]) # 类别 0, 1, 2 的概率分别是 10%, 30%, 60%
# 创建 Categorical 分布
dist = torch.distributions.categorical.Categorical(probs)
# 采样一个类别
sample = dist.sample()
print(sample) # 输出可能是 0, 1, 或 2,概率分别为 10%, 30%, 60%
(2)用 logits
采样
logits = torch.tensor([1.0, 2.0, 3.0]) # 未归一化的 logits
dist = torch.distributions.categorical.Categorical(logits=logits)
sample = dist.sample()
print(sample) # 输出 0, 1, 2 的概率由 softmax(logits) 决定
内部计算方式:
probs = torch.nn.functional.softmax(logits, dim=-1)
所以,logits [1.0, 2.0, 3.0]
会被转换为:
probs = torch.tensor([0.0900, 0.2447, 0.6652]) # softmax 归一化后的概率
4. 批量采样
如果 probs
是一个二维 Tensor,则可以对多个分布进行批量采样:
probs = torch.tensor([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]]) # 3 组分布
dist = torch.distributions.categorical.Categorical(probs)
samples = dist.sample()
print(samples) # 每个样本是 0 或 1(按不同的行概率分布)
解释:
probs[0] = [0.2, 0.8]
,第 1 个分布中1
的概率是80%
probs[1] = [0.5, 0.5]
,第 2 个分布是均匀分布probs[2] = [0.9, 0.1]
,第 3 个分布中0
的概率是90%
5. 计算 log_prob
(计算样本的对数概率)
可以计算某个类别出现的 对数概率:
probs = torch.tensor([0.1, 0.3, 0.6])
dist = torch.distributions.categorical.Categorical(probs)
log_prob = dist.log_prob(torch.tensor(2)) # 计算类别 2 的对数概率
print(log_prob) # 输出: -0.5108 (即 log(0.6))
log_prob
的作用:
log_prob(x)
计算 类别x
出现的 log 概率,即log(P(x))
- 常用于计算损失函数(如交叉熵)
总结
torch.distributions.categorical.Categorical
是 PyTorch 中用于处理离散分类分布的工具。它支持从分布中采样、计算对数概率和熵,并且可以处理多维输入。在自然语言处理、掩码语言模型和强化学习等任务中,分类分布是一个非常重要的工具。