当前位置: 首页 > news >正文

[Pyro概率编程] 推理算法Infer | 随机变分推断SVI | MCMC采样机制

第六章:推理算法(Infer)

欢迎回来(。・∀・)

通过前几章的学习,我们已经对于 参数空间 我们了解了:

  • 第一章:基础构件 介绍了可学习参数pyro.param
  • 第三章:参数存储库 解析了参数管理机制
  • 第五章:优化器 揭示了参数调整的核心算法

但优化器如同需要训练计划的教练,必须明确优化目标与评估方式。推理算法正是为此而生,它们为优化过程提供系统性解决方案

核心理念

假设我们构建了一个描述数据生成机制的复杂概率模型(“配方”),其中包含未知参数或隐变量(“秘制调料”)。我们已收集实际观测数据(“成品菜肴”),需要逆向推导最优参数组合

推理算法就是完成这一逆向工程的"主厨",通过协调模型、数据与优化器,系统性地探索参数空间以找到最优解。

Pyro提供两大核心推理方法:

  1. 随机变分推断(SVI
    通过迭代优化近似后验分布的参数,快速寻找近似最优解

  2. 马尔可夫链蒙特卡洛(MCMC
    通过系统采样探索参数空间,获取精确后验分布(计算成本较高)

关键组件

推理算法整合以下核心元素:

  • 模型(Model):描述数据与隐变量关系的概率程序
  • 引导函数(Guide,仅SVI):近似隐变量后验分布的简化模型
  • 观测数据:用于逆向推导的实际观测结果
  • 损失函数:衡量当前参数拟合程度的量化指标
  • 优化器:执行参数更新的优化算法

随机变分推断(SVI)实践

SVI将学习问题转化为优化问题,通过最小化证据下界(ELBO)实现高效推理:

实施步骤

  1. 定义概率模型:完整描述数据生成过程
  2. 构建引导函数:声明可学习参数,匹配模型隐变量结构
  3. 选择优化器:如Adam,配置学习率等超参数
  4. 指定损失函数:常用Trace_ELBO
  5. 实例化SVI对象:整合模型、引导函数、优化器与损失函数
  6. 执行训练循环:迭代调用step()更新参数

以硬币偏置学习为例(10次投掷观测到8次正面):

import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
import torch# 观测数据(8次正面,2次反面)
observed_data = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0])# 1. 定义生成模型
def coin_model(data):p = pyro.sample("p", dist.Beta(1.0, 1.0))  # 均匀先验with pyro.plate("data_loop", len(data)):pyro.sample("obs", dist.Bernoulli(p), obs=data)# 2. 构建引导函数
def coin_guide(data):alpha = pyro.param("guide_alpha", torch.tensor(1.0), constraint=dist.constraints.positive)beta = pyro.param("guide_beta", torch.tensor(1.0),constraint=dist.constraints.positive)pyro.sample("p", dist.Beta(alpha, beta))# 初始化环境
pyro.clear_param_store()
adam = optim.Adam({"lr": 0.01})
elbo = Trace_ELBO()
svi = SVI(coin_model, coin_guide, adam, elbo)# 执行训练
for i in range(1000):loss = svi.step(observed_data)if i % 100 == 0:print(f"Iter {i}, Loss: {loss:.2f}")# 输出学习结果
alpha_learned = pyro.param("guide_alpha").item()
beta_learned = pyro.param("guide_beta").item()
print(f"学习参数: alpha={alpha_learned:.2f}, beta={beta_learned:.2f}")
print(f"推断正面概率: {alpha_learned/(alpha_learned+beta_learned):.2f}")  # 预期接近0.8

实现解析

  1. coin_model定义硬币偏置p的Beta先验及观测数据生成过程
  2. coin_guide使用可学习的Beta分布参数alphabeta近似后验
  3. SVI通过ELBO损失指导参数优化,最终alphabeta收敛至近似真实后验的参数

马尔可夫链蒙特卡洛(MCMC)实践

马尔可夫链

马尔可夫链是一种数学模型,描述一个系统的状态变化仅依赖于当前状态,与过去的历史无关

生活例子
假设每天的天气只有“晴”或“雨”两种状态,且今天的天气仅由昨天的天气决定:

  • 昨天是晴天,今天有70%概率继续晴,30%概率转雨;
  • 昨天是雨天,今天有50%概率继续雨,50%概率转晴。
    这种“明天的天气只取决于今天”的特性就是马尔可夫链的核心。

MCMC通过马尔可夫链采样直接获取后验分布样本,无需引导函数:

实施步骤

  1. 定义概率模型:同SVI的定义方式
  2. 选择MCMC核:如NUTS(No-U-Turn Sampler)算法
  3. 配置MCMC参数:采样次数、预热步数等
  4. 执行采样过程:运行run()方法获取后验样本
  5. 分析采样结果:计算统计量评估后验分布

延续硬币偏置案例:

from pyro.infer import MCMC, NUTS# 定义生成模型(同SVI案例)
def coin_model_mcmc(data):p = pyro.sample("p", dist.Beta(1.0, 1.0))with pyro.plate("data_loop", len(data)):pyro.sample("obs", dist.Bernoulli(p), obs=data)# 执行MCMC采样
pyro.clear_param_store()
nuts_kernel = NUTS(coin_model_mcmc)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(observed_data)# 获取后验样本
mcmc_samples = mcmc.get_samples()
p_posterior = mcmc_samples["p"]
print(f"后验均值: {p_posterior.mean().item():.2f}")
print(f"后验标准差: {p_posterior.std().item():.2f}")

实现解析

  1. NUTS算法通过哈密顿动力学高效探索参数空间
  2. 经过500步预热使马尔可夫链收敛,后续1000步采样获取稳定后验分布
  3. 样本统计量显示p的估计值及其不确定性

SVI与MCMC对比指南

特性SVIMCMC
输出形式近似分布的参数后验分布的真实样本
计算效率适合大规模数据与复杂模型高维模型计算成本较高
引导函数需求必需无需
实时性支持在线学习通常需要完整数据集
精度依赖引导函数近似能力渐进精确
典型应用神经网络结合概率模型小数据精确推断

底层实现解析

SVI工作流程

  1. 引导函数追踪:记录可学习参数
  2. 联合执行模型:通过Poutine效果处理器计算模型与引导函数的对数概率
  3. ELBO计算:构建证据下界作为损失函数
  4. 反向传播:计算参数梯度
  5. 优化器更新:调用PyroOptim更新参数存储库

在这里插入图片描述

MCMC采样机制

  1. 预热阶段:调整采样步长等参数使链收敛
  2. 采样阶段:基于哈密顿动力学提案,按Metropolis-Hastings准则接受/拒绝
  3. 样本收集:存储通过检验的参数值构建后验分布
# 简化MCMC核心逻辑
for _ in range(预热步数):生成参数提案计算接受概率按概率接受/拒绝提案for _ in range(采样步数):生成新提案计算接受概率收集接受样本

总结

本章解析了Pyro两大推理范式:

  • SVI:通过变分近似实现高效参数学习,适合与深度学习架构整合
  • MCMC:通过精确采样获取后验分布,适合小规模精确推断

后续章节将探讨如何通过PyroModule将概率模型与PyTorch模块无缝结合,构建更复杂的混合架构。

http://www.dtcms.com/a/333817.html

相关文章:

  • linux 设备驱动的分层思想
  • MySQL的学习笔记
  • Python 常用库速查手册
  • 小红书帖子评论的nodejs爬虫脚本
  • C++编程学习(第24天)
  • 数据结构与算法p4
  • Eclipse:关闭项目
  • 【121页PPT】锂膜产业MESERP方案规划建议(附下载方式)
  • Git、JSON、MQTT
  • ramdisk内存虚拟盘(一)——前世今生
  • 嵌入式第二十九课!!!回收子进程资源空间函数与exec函数
  • SurperSet柱状图排序失效问题解决
  • 移动板房的网络化建设
  • python中的reduce函数
  • FTP定时推拉数据思考
  • 深入理解 Python 闭包:从原理到实践
  • AI - MCP 协议(一)
  • NY232NY236美光固态闪存NY240NY241
  • Dummy步进电机驱动使用和相关问题
  • 疏老师-python训练营-Day46通道注意力(SE注意力)
  • 高通vendor app访问文件
  • 【使用三化总结大模型基础概念】
  • 淘宝/天猫店铺商品搜索利器:taobao.item_search_shop API返回值详解
  • 【秋招笔试】2025.08.15饿了么秋招机考-第一题
  • 嵌入式linux学习 -- 进程和线程
  • CIAIE 2025上海汽车内外饰展观察:从美学到功能的产业跃迁
  • Redis 启动时出现 “Bad file format reading the append only file“ 错误
  • 【万字精讲】 左枝清减·右枝丰盈:C++构筑的二叉搜索森林
  • office2016常见故障解决方法
  • 第七十一章:AI的“个性定制服务”:微调 LLM vs 微调 Diffusion 模型——谁是“魔改之王”?