当前位置: 首页 > news >正文

[Pyro] 基础构件 | 随机性sample | 可学习参数param | 批量处理plate

链接:https://docs.pyro.ai/en/stable/


docs:Pyro

Pyro 是一个基于 PyTorch 构建的**深度概率编程语言**,支持用户构建、组合和学习复杂概率模型。其核心架构包含:

  • 基础构件:定义随机变量与可学习参数
  • 概率分布:描述不确定性
  • 效果处理器(Poutine):动态调整程序行为
  • 推理算法:整合SVI与MCMC等优化方法
  • 参数存储库:集中管理可学习参数
  • PyroModule:实现神经网络与概率模型的无缝集成

架构

在这里插入图片描述

章节导航

  1. 基础构件
  2. 概率分布
  3. 参数存储库(ParamStore)
  4. 效果处理器(Poutine)
  5. 优化器(Optim)
  6. 推理算法(Infer)
  7. PyroModule

第一章:基础构件

欢迎来到Pyro的精彩世界(。・∀・)

在本章中,我们将深入探索构建概率模型的核心机制——“基础构件”。

这些构件如同烹饪食谱中的基础食材,是构建概率世界的基本指令单元。

核心理念

设想我们需要建立计算机模型来模拟不确定性事件

例如咖啡店每小时客流量波动,或根据学习时长预测学生考试通过概率。真实世界的这些现象都具有随机性特征。

构建此类模型需要三个关键步骤:

  1. 定义随机事件:“客户数量遵循特定模式的随机分布”
  2. 声明可学习参数:“客户平均数量可能随时间变化,模型需从数据中自主发现规律”
  3. 处理事件集合:“拥有多时段的客户数据,各时段客流量相互独立”

这正是Pyro基础构件的核心价值~

这些特殊函数能将上述概念转化为Python代码,构建Pyro可识别并用于学习的"概率程序"。

三大核心构件包括:

  • pyro.sample:生成随机变量(如"从牌堆抽卡")
  • pyro.param:声明可学习参数(如"调整配方糖量")
  • pyro.plate:声明数据批次的独立性(如"为烤盘每个饼干重复操作")

让我们深入解析每个构件。

1. pyro.sample:引入随机性

pyro.sample是模型引入随机性的核心方式。使用该构件即告知Pyro:“此处需从指定概率分布抽取随机值”

使用范例

调用pyro.sample时需指定随机变量name和分布对象:

import pyro
import pyro.distributions as dist
import torchdef coin_flip_model():# 定义伯努利分布(类似抛硬币:0反面,1正面)# 设置0.5的正面概率coin_prob = torch.tensor(0.5)# 使用pyro.sample从该分布抽取名为"flip"的随机变量flip = pyro.sample("flip", dist.Bernoulli(coin_prob))return flip# 运行简单模型
outcome = coin_flip_model()
print(f"抛硬币结果: {int(outcome)}")

运行机制:执行coin_flip_model()时,pyro.sample会从Bernoulli(0.5)分布抽取0或1的随机值,每次运行可能获得不同结果,如同真实抛硬币!

底层实现原理

pyro.sample并非简单函数调用,而是创建描述随机事件的"消息对象"(包含名称、分布等信息)。

该消息通过称为"poutine effect handler stack"的内部系统处理(详见Poutine),实现执行过程记录,这对推理算法至关重要。

简化执行流程:

在这里插入图片描述

2. pyro.param:声明可学习参数

概率模型常包含需从数据学习的参数。例如模拟偏置硬币时,需根据观测数据学习实际正面概率。

pyro.param正是声明此类可学习参数的核心构件

使用范例

调用pyro.param需指定参数name和初始值(PyTorch张量),Pyro将自动管理参数优化:

import pyro
import pyro.distributions as dist
import torchdef learnable_coin_model():# 声明名为"coin_bias"的可学习参数# 初始值设为0.5(公平硬币假设)# 参数值将由Pyro存储和更新coin_bias = pyro.param("coin_bias", torch.tensor(0.5, requires_grad=True))# 确保偏置值在[0,1]区间(概率约束)# 使用torch.sigmoid将任意实数映射到[0,1]# 约束与分布的深入解析见[Distributions](02_distributions_.md)constrained_bias = torch.sigmoid(coin_bias)# 基于可学习偏置进行伯努利采样flip = pyro.sample("flip", dist.Bernoulli(constrained_bias))return flip# 运行模型仍会获得随机结果
# 参数学习过程将在后续推理算法中实现
_ = learnable_coin_model()# 访问当前参数值
current_bias = pyro.param("coin_bias")
print(f"当前可学习硬币偏置(初始): {current_bias.item()}")

运行机制:通过pyro.param声明参数后,Pyro将识别其为可优化参数。requires_grad=True是PyTorch标准参数,用于梯度追踪,这对优化过程至关重要。

底层实现原理

调用pyro.param时,Pyro会访问全局"参数存储库"

该存储库类似特殊字典,集中管理所有可学习参数。

  • 新参数会被初始化存储

  • 已有参数则返回当前值,实现模型不同组件间的参数共享。

详细机制见参数存储库,简化流程:

在这里插入图片描述

3. pyro.plate:处理独立性与批次

处理多独立数据点时(如100次抛硬币实验),逐次调用pyro.sample效率低下。pyro.plate通过声明代码块的独立性,支持批量处理

使用范例

通过上下文管理器使用pyro.plate,指定name、数据总量及可选批次大小:

import pyro
import pyro.distributions as dist
import torchdef hundred_flips_model():# 声明可学习硬币偏置coin_bias = pyro.param("coin_bias", torch.tensor(0.5, requires_grad=True))constrained_bias = torch.sigmoid(coin_bias)# 声明100次独立抛硬币实验# 'data'为唯一标识名# '100'为独立事件总数with pyro.plate("flips_plate", 100):# 该代码块内的pyro.sample调用将视为100次独立采样# 单行代码实现100个独立"flip"采样点flips = pyro.sample("flip", dist.Bernoulli(constrained_bias))# flips将是包含100次结果的张量return flips# 运行模型
all_outcomes = hundred_flips_model()
print(f"100次抛硬币结果维度: {all_outcomes.shape}")
print(f"前10次结果: {all_outcomes[:10].int().tolist()}")

运行机制pyro.plate实现向量化采样,flips变量成为包含100元素的PyTorch张量,大幅提升大数据集处理效率

支持子采样处理大规模数据:

import pyro
import pyro.distributions as dist
import torchdef mini_batch_flips_model():coin_bias = pyro.param("coin_bias", torch.tensor(0.5, requires_grad=True))constrained_bias = torch.sigmoid(coin_bias)total_flips = 1000  # 假设总样本量batch_size = 100    # 单次处理100样本# 使用带子采样的pyro.platewith pyro.plate("flips_plate", total_flips, subsample_size=batch_size) as ind:# ind提供当前批次索引print(f"处理批次大小: {len(ind)}")# 模拟大数据集子采样flips_batch = pyro.sample("flip", dist.Bernoulli(constrained_bias).expand([len(ind)]).to_independent(1))# 注:expand和to_independent用于匹配批次维度,暂不需深究# 实际应用中可通过ind索引数据子集print(f"采样批次维度: {flips_batch.shape}")_ = mini_batch_flips_model()

运行机制:带subsample_sizepyro.plate会从1000样本中随机抽取100个索引。

Pyro自动处理概率缩放,确保推理算法正确运作,这对大规模数据处理至关重要。

底层实现原理

pyro.plate通过向运行时栈添加特殊"messenger",通知代码块内的pyro.sample处理批次维度和独立性。当使用子采样时,自动应用缩放因子确保推理正确性。具体实现详见Poutine。

总结

本章解析了Pyro三大基础构件:

构件功能定位类比说明
pyro.sample定义随机变量随机抽牌
pyro.param声明可学习参数调整配方成分
pyro.plate声明数据批次的独立性批量处理饼干制作步骤

这些构件是构建概率模型的核心语言

通过组合运用,可实现强大灵活的模型构建。

掌握基础构件后,下一步是理解如何描述各类随机现象,这将引导我们进入概率分布的探索。

http://www.dtcms.com/a/333266.html

相关文章:

  • find命令解读
  • 重塑工业设备制造格局:明远智睿 T113-i 的破局之道
  • 2025北京世界机器人大会:技术、场景、生态实现三重跃迁
  • ARM+OpenPLC 组合详解及经典示例
  • MySQL → SQL → DDL → 表操作 → 数据类型 知识链整理成一份系统的内容
  • 基于 ArcFace/ArcMargin 损失函数的深度特征学习高性能人脸识别解决方案
  • pandas中df.to _dict(orient=‘records‘)方法的作用和场景说明
  • 题解:CF2127D Root was Built by Love, Broken by Destiny
  • CUDA × JetPack 初学者全指南
  • Python工具箱系列(六十四)
  • go语言运算符·关系运算符
  • sql CURRENT_TIMESTAMP
  • 【DSP28335 事件驱动】唤醒沉睡的 CPU:外部中断 (XINT) 实战
  • java注释功能
  • ESP32-C3_TCP
  • Linux操作系统从入门到实战(二十二)命令行参数与环境变量
  • 信刻光盘摆渡系统案例——某省纪委
  • 微服务容错与监控体系设计
  • 生存主义:隐形异变 (Survivalist: Invisible Strain)免安装中文版
  • Leetcode 最小生成树系列(1)
  • 解决zabbix图片中文乱码
  • Mac(二)Homebrew 的安装和使用
  • 前端更改浏览器默认滚动条样式
  • 716SJBH高职院校财务收费系统的设计与实现
  • 25. 移动端-uni-app
  • 【论文解读】DDRNet:深度双分辨率网络在实时语义分割中的结构与原理全面剖析
  • LeetCode 905.按奇偶排序数组
  • 【机器学习深度学习】客观评估主观评估:落地场景权重比例
  • Rust 中 i32 与 *i32 的深度解析
  • 大华相机RTSP无法正常拉流问题分析与解决