PyTorch 之 torch.distributions.Categorical 详解
PyTorch 之 torch.distributions.Categorical 详解
- PyTorch 之 torch.distributions.Categorical 详解
- 一、创建分类分布
- (一)基本语法
- (二)示例
- 二、采样
- (一)方法
- (二)示例
- 三、计算概率
- (一)方法
- (二)示例
- 四、计算对数概率
- (一)方法
- (二)示例
- 五、其他方法
- (一)计算熵
- (二)枚举支持集
- (三)获取均值和方差
- 六、实际应用场景
- (一)强化学习中的策略选择
- (二)自然语言处理中的单词预测
PyTorch 之 torch.distributions.Categorical 详解
在深度学习的诸多任务中,我们常常需要处理离散概率分布,比如在自然语言处理中对词汇表中的单词进行采样,或者在强化学习中从策略网络输出的动作概率分布中选择动作。PyTorch 提供了 torch.distributions.Categorical
类,方便我们高效地创建和操作离散分类分布。本文将深入讲解这个类的用法,帮助你在实际项目中更好地利用它。
一、创建分类分布
(一)基本语法
torch.distributions.Categorical
的基本语法是:
torch.distributions.Categorical(probs=None, logits=None)
其中:
probs
:一个张量,表示每个类别的概率。它的值应该非负,并且所有元素的和为 1。例如,torch.tensor([0.1, 0.2, 0.3, 0.4])
表示有四个类别,它们的概率分别是 0.1、0.2、0.3 和 0.4。logits
:一个张量,表示每个类别的未归一化对数概率。系统会自动将其转换为概率值。比如,torch.tensor([1.0, 2.0, 3.0, 4.0])
会被处理成相应的概率分布。
(二)示例
import torch# 使用 probs 参数
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
categorical_dist = torch.distributions.Categorical(probs)# 使用 logits 参数
logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
categorical_dist_logits = torch.distributions.Categorical(logits=logits)
二、采样
(一)方法
使用 sample
方法进行采样,语法是:
sample(sample_shape=torch.Size())
其中,sample_shape
是一个元组,用于指定采样的样本数量和形状。默认为空,表示采样一个样本。
(二)示例
# 采样一个样本
sample = categorical_dist.sample()
print(sample) # 输出一个类别索引,比如 tensor(3)# 采样多个样本
samples = categorical_dist.sample((5,))
print(samples) # 输出一个形状为 [5] 的张量,包含 5 个类别索引,如 tensor([2, 0, 3, 1, 3])
三、计算概率
(一)方法
借助 prob
方法计算概率,语法如下:
prob(value)
这里,value
是一个张量,表示类别索引,取值范围为 [0, num_categories - 1]
。
(二)示例
# 计算单个值的概率
prob_value = categorical_dist.prob(torch.tensor(2))
print(prob_value) # 输出类别索引为 2 的概率值,如 tensor(0.3)# 计算多个值的概率
prob_values = categorical_dist.prob(torch.tensor([0, 1, 2, 3]))
print(prob_values) # 输出一个形状为 [4] 的张量,包含每个类别索引对应的概率值,如 tensor([0.1, 0.2, 0.3, 0.4])
四、计算对数概率
(一)方法
调用 log_prob
方法计算对数概率,语法是:
log_prob(value)
参数 value
的含义和 prob
方法中的相同。
(二)示例
# 计算单个值的对数概率
log_prob_value = categorical_dist.log_prob(torch.tensor(1))
print(log_prob_value) # 输出类别索引为 1 的对数概率值,比如 tensor(-1.6094)# 计算多个值的对数概率
log_prob_values = categorical_dist.log_prob(torch.tensor([0, 1, 2, 3]))
print(log_prob_values) # 输出一个形状为 [4] 的张量,包含每个类别索引对应的对数概率值,如 tensor([-2.3026, -1.6094, -1.2039, -0.9163])
五、其他方法
(一)计算熵
使用 entropy
方法计算分类分布的熵,熵反映了分布的不确定性。值越大,表示不确定性越高。示例代码如下:
# 计算熵
entropy_value = categorical_dist.entropy()
print(entropy_value) # 输出一个值,如 tensor(1.3777)
(二)枚举支持集
通过 enumerate_support
方法枚举分布的支持集,即所有可能的类别索引。示例如下:
# 枚举支持集
support = categorical_dist.enumerate_support()
print(support) # 输出一个形状为 [4] 的张量,如 tensor([0, 1, 2, 3])
(三)获取均值和方差
可以直接访问 mean
和 variance
属性,分别获取分布的均值和方差。示例:
# 获取均值和方差
mean_value = categorical_dist.mean
variance_value = categorical_dist.variance
print(mean_value, variance_value) # 输出类似 tensor(2.3000) tensor(1.8100)
六、实际应用场景
(一)强化学习中的策略选择
在强化学习里,策略网络常输出动作的概率分布。此时,可以利用 torch.distributions.Categorical
来创建这个分布,然后通过采样来选择动作。例如:
import torch
import torch.nn as nn# 假设策略网络的输出是 logits
class PolicyNetwork(nn.Module):def __init__(self):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(4, 128) # 输入状态维度为 4self.fc2 = nn.Linear(128, 2) # 输出动作 logits,假设有 2 个动作def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xpolicy_net = PolicyNetwork()
state = torch.tensor([1.0, 2.0, 3.0, 4.0]) # 当前状态
logits = policy_net(state)
action_dist = torch.distributions.Categorical(logits=logits)
action = action_dist.sample() # 采样得到动作
print(action) # 输出动作索引,如 tensor(1)
(二)自然语言处理中的单词预测
在语言模型中,模型会预测下一个单词的概率分布。使用 torch.distributions.Categorical
可以方便地处理这个分布,比如进行采样生成文本。例如:
import torch
import torch.nn as nn# 假设语言模型的输出是单词的概率
class LanguageModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim):super(LanguageModel, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, x):x = self.embedding(x)x, _ = self.lstm(x)x = self.fc(x[:, -1, :]) # 取最后一个时间步的输出预测下一个单词return xvocab_size = 10000 # 词汇表大小
embedding_dim = 128
hidden_dim = 256
lm = LanguageModel(vocab_size, embedding_dim, hidden_dim)
input_word_indices = torch.tensor([1, 2, 3, 4]) # 输入单词的索引序列
probs = lm(input_word_indices)
word_dist = torch.distributions.Categorical(probs=probs)
next_word = word_dist.sample() # 采样得到下一个单词的索引
print(next_word) # 输出单词索引,如 tensor(125)
torch.distributions.Categorical
类是 PyTorch 中处理离散概率分布的有力工具。它丰富的功能使得在涉及到分类数据的概率操作时变得简单高效。掌握这个类的用法,能让你在强化学习、自然语言处理等诸多领域更加得心应手地构建和训练模型。建议你在实际项目中多加练习,深入理解其原理和应用场景。