当前位置: 首页 > news >正文

PyTorch 之 torch.distributions.Categorical 详解

PyTorch 之 torch.distributions.Categorical 详解

  • PyTorch 之 torch.distributions.Categorical 详解
    • 一、创建分类分布
      • (一)基本语法
      • (二)示例
    • 二、采样
      • (一)方法
      • (二)示例
    • 三、计算概率
      • (一)方法
      • (二)示例
    • 四、计算对数概率
      • (一)方法
      • (二)示例
    • 五、其他方法
      • (一)计算熵
      • (二)枚举支持集
      • (三)获取均值和方差
    • 六、实际应用场景
      • (一)强化学习中的策略选择
      • (二)自然语言处理中的单词预测

PyTorch 之 torch.distributions.Categorical 详解

在深度学习的诸多任务中,我们常常需要处理离散概率分布,比如在自然语言处理中对词汇表中的单词进行采样,或者在强化学习中从策略网络输出的动作概率分布中选择动作。PyTorch 提供了 torch.distributions.Categorical 类,方便我们高效地创建和操作离散分类分布。本文将深入讲解这个类的用法,帮助你在实际项目中更好地利用它。

一、创建分类分布

(一)基本语法

torch.distributions.Categorical 的基本语法是:

torch.distributions.Categorical(probs=None, logits=None)

其中:

  • probs:一个张量,表示每个类别的概率。它的值应该非负,并且所有元素的和为 1。例如,torch.tensor([0.1, 0.2, 0.3, 0.4]) 表示有四个类别,它们的概率分别是 0.1、0.2、0.3 和 0.4。
  • logits:一个张量,表示每个类别的未归一化对数概率。系统会自动将其转换为概率值。比如,torch.tensor([1.0, 2.0, 3.0, 4.0]) 会被处理成相应的概率分布。

(二)示例

import torch# 使用 probs 参数
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
categorical_dist = torch.distributions.Categorical(probs)# 使用 logits 参数
logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
categorical_dist_logits = torch.distributions.Categorical(logits=logits)

二、采样

(一)方法

使用 sample 方法进行采样,语法是:

sample(sample_shape=torch.Size())

其中,sample_shape 是一个元组,用于指定采样的样本数量和形状。默认为空,表示采样一个样本。

(二)示例

# 采样一个样本
sample = categorical_dist.sample()
print(sample)  # 输出一个类别索引,比如 tensor(3)# 采样多个样本
samples = categorical_dist.sample((5,))
print(samples)  # 输出一个形状为 [5] 的张量,包含 5 个类别索引,如 tensor([2, 0, 3, 1, 3])

三、计算概率

(一)方法

借助 prob 方法计算概率,语法如下:

prob(value)

这里,value 是一个张量,表示类别索引,取值范围为 [0, num_categories - 1]

(二)示例

# 计算单个值的概率
prob_value = categorical_dist.prob(torch.tensor(2))
print(prob_value)  # 输出类别索引为 2 的概率值,如 tensor(0.3)# 计算多个值的概率
prob_values = categorical_dist.prob(torch.tensor([0, 1, 2, 3]))
print(prob_values)  # 输出一个形状为 [4] 的张量,包含每个类别索引对应的概率值,如 tensor([0.1, 0.2, 0.3, 0.4])

四、计算对数概率

(一)方法

调用 log_prob 方法计算对数概率,语法是:

log_prob(value)

参数 value 的含义和 prob 方法中的相同。

(二)示例

# 计算单个值的对数概率
log_prob_value = categorical_dist.log_prob(torch.tensor(1))
print(log_prob_value)  # 输出类别索引为 1 的对数概率值,比如 tensor(-1.6094)# 计算多个值的对数概率
log_prob_values = categorical_dist.log_prob(torch.tensor([0, 1, 2, 3]))
print(log_prob_values)  # 输出一个形状为 [4] 的张量,包含每个类别索引对应的对数概率值,如 tensor([-2.3026, -1.6094, -1.2039, -0.9163])

五、其他方法

(一)计算熵

使用 entropy 方法计算分类分布的熵,熵反映了分布的不确定性。值越大,表示不确定性越高。示例代码如下:

# 计算熵
entropy_value = categorical_dist.entropy()
print(entropy_value)  # 输出一个值,如 tensor(1.3777)

(二)枚举支持集

通过 enumerate_support 方法枚举分布的支持集,即所有可能的类别索引。示例如下:

# 枚举支持集
support = categorical_dist.enumerate_support()
print(support)  # 输出一个形状为 [4] 的张量,如 tensor([0, 1, 2, 3])

(三)获取均值和方差

可以直接访问 meanvariance 属性,分别获取分布的均值和方差。示例:

# 获取均值和方差
mean_value = categorical_dist.mean
variance_value = categorical_dist.variance
print(mean_value, variance_value)  # 输出类似 tensor(2.3000) tensor(1.8100)

六、实际应用场景

(一)强化学习中的策略选择

在强化学习里,策略网络常输出动作的概率分布。此时,可以利用 torch.distributions.Categorical 来创建这个分布,然后通过采样来选择动作。例如:

import torch
import torch.nn as nn# 假设策略网络的输出是 logits
class PolicyNetwork(nn.Module):def __init__(self):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(4, 128)  # 输入状态维度为 4self.fc2 = nn.Linear(128, 2)  # 输出动作 logits,假设有 2 个动作def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xpolicy_net = PolicyNetwork()
state = torch.tensor([1.0, 2.0, 3.0, 4.0])  # 当前状态
logits = policy_net(state)
action_dist = torch.distributions.Categorical(logits=logits)
action = action_dist.sample()  # 采样得到动作
print(action)  # 输出动作索引,如 tensor(1)

(二)自然语言处理中的单词预测

在语言模型中,模型会预测下一个单词的概率分布。使用 torch.distributions.Categorical 可以方便地处理这个分布,比如进行采样生成文本。例如:

import torch
import torch.nn as nn# 假设语言模型的输出是单词的概率
class LanguageModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim):super(LanguageModel, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, x):x = self.embedding(x)x, _ = self.lstm(x)x = self.fc(x[:, -1, :])  # 取最后一个时间步的输出预测下一个单词return xvocab_size = 10000  # 词汇表大小
embedding_dim = 128
hidden_dim = 256
lm = LanguageModel(vocab_size, embedding_dim, hidden_dim)
input_word_indices = torch.tensor([1, 2, 3, 4])  # 输入单词的索引序列
probs = lm(input_word_indices)
word_dist = torch.distributions.Categorical(probs=probs)
next_word = word_dist.sample()  # 采样得到下一个单词的索引
print(next_word)  # 输出单词索引,如 tensor(125)

torch.distributions.Categorical 类是 PyTorch 中处理离散概率分布的有力工具。它丰富的功能使得在涉及到分类数据的概率操作时变得简单高效。掌握这个类的用法,能让你在强化学习、自然语言处理等诸多领域更加得心应手地构建和训练模型。建议你在实际项目中多加练习,深入理解其原理和应用场景。

相关文章:

  • MATLAB中进行语音信号分析
  • USB学习【13】STM32+USB接收数据过程详解
  • 关于element-ui的table type=“expand“ 嵌套表格展开异常问题解决方案
  • CYT4BB Dual Bank 1 - 存储机制
  • 02 基本介绍及Pod基础排错
  • P/Invoke 内存资源处理方案
  • Linux bash shell的循环命令for、while和until
  • C++面向对象——多态
  • 单片机复用功能重映射Remap功能
  • 基于单片机的车辆防盗系统设计与实现
  • 第六部分:第三节 - 路由与请求处理:解析顾客的点单细节
  • 【基础知识】SPI协议的种类及异同
  • OpenCV CUDA 模块特征检测与描述------在GPU上执行特征描述符匹配的类cv::cuda::DescriptorMatcher
  • SetThrowSegvLongjmpSEHFilter错误和myFuncInitialize 崩溃
  • 宝塔+fastadmin:给项目添加定时任务
  • 汽车区域电子电气架构(Zonal E/E)的统一
  • CentOS 7上BIND9配置DNS服务器指南
  • SpringSecurity基础入门
  • 使用Mathematica绘制一类矩阵的特征值图像
  • 使用F5-tts复刻音色
  • 上海一隧道成“王家卫风”网红拍照点?交管部门已专项整治,一人被处罚
  • 科学与艺术的跨界对话可能吗?——评“以蚁为序的生命网络”
  • 来论|以法治之力激发民营经济新动能
  • 证监会副主席李明:支持符合条件的外资机构申请新业务、设立新产品
  • 俄乌刚谈完美国便筹划与两国领导人通话,目的几何?
  • 英国知名歌手批政府:让AI公司免费使用艺术家作品是盗窃