当前位置: 首页 > news >正文

deepbayes lecture2:变分推断

变分推断简介

在机器学习领域,贝叶斯推断是一种强大的框架,它允许我们结合先验知识和观测数据来估计模型参数的后验分布。然而,在实际应用中,后验分布的计算往往是难以处理的,这限制了贝叶斯方法的广泛应用。为了解决这个问题,变分推断(Variational Inference,VI)应运而生,它提供了一种近似贝叶斯推断的有效途径。

什么是变分推断?

变分推断是一种用于近似计算复杂后验分布的技术。它的核心思想是将后验推断问题转化为一个优化问题,通过寻找一个与真实后验分布尽可能接近的简单分布来近似它。

具体来说,变分推断的目标是找到一个属于特定分布族 Q 的分布 q(θ),使得 q(θ) 与真实的后验分布 p(θ|x) 之间的 Kullback-Leibler (KL) 散度 最小化。KL 散度是一种衡量两个概率分布差异的指标,KL散度越小,表示两个分布越接近。

用公式表达如下:

F ( q ) : = K L ( q ( θ ) ∥ p ( θ ∣ x ) ) → min ⁡ q ( θ ) ∈ Q F(q) := KL(q(\theta) \parallel p(\theta | x)) \rightarrow \min_{q(\theta)\in Q} F(q):=KL(q(θ)p(θx))q(θ)Qmin

为什么需要变分推断?

完全贝叶斯推断在训练阶段需要计算后验分布,在测试阶段需要对后验分布进行积分。但是,只有在简单的共轭模型中才能分析地计算后验分布。在大多数情况下,后验计算是难以处理的。因此,需要近似推理。

变分推断MCMC 都是近似推理方法。其中:

  • 变分推断(Variational Inference):

    • 近似:p(θ | x) ≈ q(θ) ∈ Q
    • 有偏差
    • 速度更快、更具可扩展性
  • MCMC:

    • 从非标准化 p(θ | x) 采样
    • 无偏差
    • 需要大量样本

变分推断的数学原理

为了理解变分推断的数学原理,我们需要引入 证据下界 (Evidence Lower Bound, ELBO) 的概念。ELBO 是对数证据 log p(x) 的一个下界,它与 KL 散度之间存在以下关系:

log ⁡ p ( x ) = L ( q ( θ ) ) + K L ( q ( θ ) ∥ p ( θ ∣ x ) ) \log p(x) = \mathcal{L}(q(\theta)) + KL(q(\theta) \parallel p(\theta | x)) logp(x)=L(q(θ))+KL(q(θ)p(θx))

其中:

  • log p(x) 是对数证据,表示观测数据的边缘概率,它是一个常数,不依赖于 q(θ)。
  • L(q(θ)) 是 ELBO,它定义为 L ( q ( θ ) ) = ∫ q ( θ ) log ⁡ [ p ( x , θ ) q ( θ ) ] d θ 。 \mathcal{L}(q(\theta)) = \int q(\theta) \log \left[\frac{p(x, \theta)}{q(\theta)}\right] d\theta。 L(q(θ))=q(θ)log[q(θ)p(x,θ)]dθ
  • KL(q(θ) || p(θ | x)) 是 q(θ) 与真实后验分布 p(θ|x) 之间的 KL 散度。

由于 KL 散度始终为非负值,因此 ELBO 是 log p(x) 的一个下界。最大化 ELBO 等价于最小化 KL 散度,从而找到一个与真实后验分布最接近的近似分布 q(θ)。

所以原始问题:

F ( q ) : = K L ( q ( θ ) ∥ p ( θ ∣ x ) ) → min ⁡ q ( θ ) ∈ Q F(q) := KL(q(\theta) \parallel p(\theta | x)) \rightarrow \min_{q(\theta)\in Q} F(q):=KL(q(θ)p(θx))q(θ)Qmin

等价于:

L ( q ( θ ) ) → max ⁡ q ( θ ) ∈ Q L(q(\theta)) \rightarrow \max_{q(\theta)\in Q} L(q(θ))q(θ)Qmax

我们可以将 ELBO 写成两种等价形式:

L ( q ( θ ) ) = E q ( θ ) [ log ⁡ p ( x ∣ θ ) ] − K L ( q ( θ ) ∥ p ( θ ) ) ( 数据项目 − 正则化项 ) \mathcal{L}(q(\theta)) = \mathbb{E}_{q(\theta)}[\log p(x | \theta)] - KL(q(\theta) \parallel p(\theta)) (数据项目-正则化项) L(q(θ))=Eq(θ)[logp(xθ)]KL(q(θ)p(θ))(数据项目正则化项) L ( q ( θ ) ) = log ⁡ p ( x ) − K L ( q ( θ ) ∥ p ( θ ∣ x ) ) \mathcal{L}(q(\theta)) = \log p(x) - KL(q(\theta) \parallel p(\theta | x)) L(q(θ))=logp(x)KL(q(θ)p(θx))

如何进行变分推断?

变分推断的关键在于选择合适的 变分分布族 Q。常见的变分分布族包括:

  • 平均场近似 (Mean Field Approximation):假设变分分布可以分解为各个参数的独立分布的乘积,即 q(θ) = ∏j qj(θj)。
  • 参数化近似 (Parametric Approximation):假设变分分布具有某种特定的参数形式,例如高斯分布或指数分布,即 q(θ) = q(θ | λ),其中 λ 是变分参数。

选择变分分布族后,我们需要优化 ELBO,找到最优的变分分布 q(θ)。常用的优化方法包括:

  • 坐标上升变分推断 (Coordinate Ascent Variational Inference, CAVI):这是一种迭代优化算法,它依次更新每个参数的变分分布,直到 ELBO 收敛。
  • 随机梯度变分推断 (Stochastic Gradient Variational Inference, SGVI):这是一种基于随机梯度的优化算法,它适用于大规模数据集和复杂模型。

平均场近似

平均场近似假定变分分布可以分解为:

q ( θ ) = ∏ j = 1 m q j ( θ j ) q(\theta) = \prod_{j=1}^m q_j(\theta_j) q(θ)=j=1mqj(θj)

这意味着我们假设 θ 的各个分量是独立的。虽然这是一种很强的假设,但在实践中,它通常能够提供良好的近似效果。

在平均场近似下,我们可以使用坐标上升算法来优化 ELBO。坐标上升算法的思想是,每次固定其他分量,然后优化一个分量。重复这个过程,直到收敛。

平均场变分推断算法:

  1. 初始化 q j ( θ j ) = 1 Z j exp ⁡ ( E q i ≠ j log ⁡ p ( x , θ ) ) q_j (\theta_j) = \frac{1}{Z_j} \exp\left(\mathbb{E}_{q_{i\neq j}} \log p(x, \theta)\right) qj(θj)=Zj1exp(Eqi=jlogp(x,θ))
  2. 迭代:
  • 更新每个因子 q1,…, qm:
    q j ( θ j ) = 1 Z j exp ⁡ ( E q i ≠ j log ⁡ p ( x , θ ) ) q_j (\theta_j) = \frac{1}{Z_j} \exp\left(\mathbb{E}_{q_{i\neq j}} \log p(x, \theta)\right) qj(θj)=Zj1exp(Eqi=jlogp(x,θ))

  • 计算 ELBO L(q(θ))

  • 重复直到 ELBO 收敛

  • 其中 Z j Z_j Zj是归一化常数

总结

变分推断是一种强大的近似贝叶斯推断方法,它通过优化 ELBO 来寻找与真实后验分布最接近的近似分布。变分推断具有计算效率高、可扩展性强等优点,在机器学习领域得到了广泛应用。希望这篇博客文章能够帮助你理解变分推断的基本概念和原理。

Reference

[deepbayes-2019/lectures/day1/2. Dmitry Vetrov - Variational inference.pdf at master · bayesgroup/deepbayes-2019 · GitHub

相关文章:

  • “详规一张图”——新加坡土地利用数据
  • Open3D 对点云进行去噪(下采样、欧式聚类分割)01
  • 基于算法竞赛的c++编程(25)指针简单介绍和简单应用
  • 【Vue】scoped+组件通信+props校验
  • DingDing机器人群消息推送
  • 二维FDTD算法仿真
  • JVM如何优化
  • Qt学习及使用_第1部分_认识Qt---Qt开发基本流程
  • AirPosture | 通过 AirPods 矫正坐姿
  • while/do while/for循环几个小细节
  • 免费数学几何作图web平台
  • React中子传父组件通信操作指南
  • JavaScript的ArrayBuffer与C++的malloc():两种内存管理方式的深度对比
  • Linux进程信号(一)
  • LLMs 系列实操科普(2)
  • Spring Boot面试题精选汇总
  • 如何做好一份技术文档?从规划到实践的完整指南
  • React从基础入门到高级实战:React 实战项目 - 项目五:微前端与模块化架构
  • ubuntu22.04 安装docker 和docker-compose
  • 安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖
  • 网站seo插件/电商seo
  • 济南微信网站制作/网址搜索引擎
  • 新乡建设公司网站/公司网站制作费用
  • 网站建设南阳/宁波网站建设推广公司价格
  • 中国开发网站的公司/网络整合营销方案ppt
  • 做网站用asp div代码/百度一下了你就知道官网