当前位置: 首页 > news >正文

使用 BayesFlow 通过神经网络简化贝叶斯推断(一)

本篇文章Easier Bayesian Inference with Neural Networks using BayesFlow (Code Included)是介绍BayesFlow 大致功能的开篇。


文章目录

  • 1 什么是 BayesFlow?
    • 1.1 工作原理:BayesFlow 工作流
    • 1.2 组件
    • 1.3 核心功能
  • 2 实际应用
    • 2.1 简单入门
    • 2.2 专家级定制
    • 2.3 使用 BayesFlow 进行贝叶斯线性回归:摊销推理的实践入门
      • 2.3.1 为什么选择摊销推理?
      • 2.3.2 核心架构:摘要网络与推理网络
      • 2.3.3 分步实现
      • 2.3.4 定义生成模型
      • 2.3.5 通过适配器准备数据
      • 2.3.6 构建神经网络
        • 2.3.6.1 摘要网络
        • 2.3.6.2 推理网络
    • 2.4 连接所有组件:摊销器
    • 2.5 后验估计
    • 2.6 结果可视化


贝叶斯推断为不确定性下的推理、复杂系统建模以及基于观测数据进行预测提供了一种有原则且强大的方法。然而,尽管贝叶斯建模优雅,但它常常遇到严重的计算障碍:

后验分布通常难以处理。

模型验证和比较需要重复推断。

基于仿真的工作流(例如,校准、恢复、敏感性分析)变得慢得令人望而却步。

这种计算成本传统上限制了贝叶斯工作流的实际应用——直到 BayesFlow 的出现。

1 什么是 BayesFlow?

BayesFlow 是一个开源的 Python 库,旨在利用摊销神经网络加速和扩展贝叶斯推断。通过训练神经网络“学习”逆问题(从数据推断参数)或正向模型(从参数生成数据),BayesFlow 可以在初始训练后实现近乎即时的推断——通常在毫秒级完成。

核心思想: 一次性投入计算资源训练神经网络,然后将其重复用于数千次快速推断

BayesFlow 基于 TensorFlow 构建,无缝支持 GPU/TPU 加速,并与 TensorFlow Probability 集成,以实现灵活的先验和潜在变量。

1.1 工作原理:BayesFlow 工作流

BayesFlow 的核心是一个形式化、模块化的架构,它模仿了传统贝叶斯工作流的关键组件,但通过神经网络近似器对其进行了超强赋能。其工作原理如下:

1.2 组件

  1. 模拟 + 先验:定义你的生成模型(例如,流行病学中的 SIR 模型)。
  2. 配置器:准备用于训练的数据(例如,归一化、嵌入)。
  3. 神经网络
    • 摘要网络:将原始模拟数据或参数压缩为密集嵌入。
    • 后验网络:学习从数据到参数的逆映射。
    • 似然网络:学习从参数到数据的正向映射。

这些网络可以组合使用,也可以根据你的任务(后验估计、似然模拟、模型比较等)独立使用。

1.3 核心功能

BayesFlow 支持现代贝叶斯工作流的四个关键功能:

  1. Amortized 后验估计
    一次训练,多次推断。实现跨数据集的完整后验快速估计。
    → 解决逆问题。
  2. Amortized 似然估计
    模拟复杂模拟器以估计似然,无需重新运行。
    → 解决正向问题。
  3. Amortized 模型比较
    根据模型解释数据的能力对模型进行分类或排序——使用学习到的后验和似然。
    → 计算贝叶斯证据和预测准确性。
  4. 模型误设定检测
    诊断你的模拟器何时不再代表现实——即使推断“有效”。
    → 避免自信地犯错。

2 实际应用

BayesFlow 不仅仅是理论——它已被部署到广泛的领域:

  • 流行病学:使用基于模拟的 SIR 模型进行疾病传播建模。
  • 神经科学与精神病学:认知和计算模型的参数恢复。
  • 地震学:地震建模中的高维逆问题。
  • 粒子物理学:复杂模拟器的快速代理模型。
  • 航空航天、MEMS、风力涡轮机:不确定性下的工程设计。

简而言之:如果你有一个模拟器,你就可以使用 BayesFlow。

2.1 简单入门

以下是入门的简单方法:

import bayesflow as bfworkflow = bf.BasicWorkflow(inference_network=bf.networks.CouplingFlow(),summary_network=bf.networks.TimeSeriesNetwork(),inference_variables=["parameters"],summary_variables=["observables"],simulator=bf.simulators.SIR()
)
history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)
diagnostics = workflow.plot_default_diagnostics(test_data=300)

无需构建复杂的训练循环——BayesFlow 处理从模拟到诊断的所有环节。

2.2 专家级定制

BayesFlow 提供:

  • 一个用户友好的 API,适用于应用研究人员。
  • 一个模块化设计,供机器学习专家插入自定义网络、训练方案或推断策略。
  • 开箱即用的默认设置,适用于许多基于模拟的模型。

无论你是为认知建模构建管道,还是为航空航天设计调整代理模型,BayesFlow 都能适应你的工作流。

2.3 使用 BayesFlow 进行贝叶斯线性回归:摊销推理的实践入门

欢迎来到我们使用 BayesFlow 的第一个演练——一个用于通过神经网络进行摊销贝叶斯推断的强大库。在本教程中,我们将使用一个简单的线性回归示例来探索摊销后验估计的基本概念,并演示 BayesFlow 的模块化架构。

我们将通过使用 BayesFlow 的低级 API 来保持透明,从而完全控制每个组件——从模拟器创建到网络架构。如果你刚开始学习并想了解内部工作原理,这将是完美的选择。

2.3.1 为什么选择摊销推理?

传统贝叶斯推断中,我们根据观测数据估计模型参数的后验分布。这通常需要计算成本高昂的方法,如 MCMC 或变分推断——对于每个新数据集都是如此。

但是,如果我们能学会推断呢?

这就是摊销贝叶斯推断的切入点:我们不是为每个新数据集从头开始计算后验,而是训练一个神经网络学习一个函数,该函数直接将数据映射到后验估计。一旦训练完成,这种方法就可以对新数据集进行即时推断

这在高吞吐量、实时或基于模拟的推断设置中尤其有价值。

2.3.2 核心架构:摘要网络与推理网络

我们的 BayesFlow 模型由两个核心网络组成:

  • 摘要网络:将可变长度的输入数据(如观测值)转换为固定长度的嵌入。
  • 推理网络:使用条件生成模型(通常是可逆神经网络)基于此嵌入学习从近似后验中采样。

这些网络共同学习“反转”一个从潜在参数生成数据的模拟器。

2.3.3 分步实现

让我们首先导入必要的库并设置 BayesFlow 环境。

import numpy as np
from pathlib import Path
import keras
import bayesflow as bfnp.set_printoptions(suppress=True)

2.3.4 定义生成模型

我们首先为基本线性回归模型定义似然

def likelihood(beta, sigma, N):x = np.random.normal(0, 1, size=N)y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N)return dict(y=y, x=x)

现在,定义我们模型参数的先验

def prior():beta = np.random.normal([2, 0], [3, 1])sigma = np.random.gamma(1, 1)return dict(beta=beta, sigma=sigma)

为了实现不同数据大小的摊销,定义一个元函数来采样数据集大小:

def meta():N = np.random.randint(5, 15)return dict(N=N)

现在,让我们将上述所有内容封装在一个 BayesFlow 模拟器中:

simulator = bf.simulators.make_simulator([prior, likelihood], meta_fn=meta)

从模拟器中采样:

sim_draws = simulator.sample(500)

2.3.5 通过适配器准备数据

BayesFlow 提供灵活的适配器管道来准备用于训练的原始模拟数据。

adapter = (bf.Adapter().broadcast("N", to="x").as_set(["x", "y"]).constrain("sigma", lower=0).standardize(exclude=["N"]).sqrt("N").convert_dtype("float64", "float32").concatenate(["beta", "sigma"], into="inference_variables").concatenate(["x", "y"], into="summary_variables").rename("N", "inference_conditions")
)

此适配器执行:

  • 上下文变量([N])的广播
  • 标准化(排除常量)
  • 维度检查
  • 连接和重塑

运行适配器:

processed_draws = adapter(sim_draws)

检查形状:

print(processed_draws["summary_variables"].shape)
print(processed_draws["inference_variables"].shape)
print(processed_draws["inference_conditions"].shape)

2.3.6 构建神经网络

2.3.6.1 摘要网络

由于我们的数据是置换不变的(顺序无关紧要),我们使用 SetTransformerDeepSet 架构从 ([x], [y]) 观测值中学习有意义的嵌入。

summary_net = bf.networks.DeepSet(input_shape=(None, 2), output_dim=64)
2.3.6.2 推理网络

我们将使用 BayesFlow 可逆网络来建模后验分布:

inference_net = bf.networks.InvertibleNetwork(n_params=3, num_coupling_layers=6)

2.4 连接所有组件:摊销器

BayesFlow 提供了一个方便的 Amortizer 类,它组合了所有组件。

amortizer = bf.amortizers.AmortizedPosterior(summary_net=summary_net,inference_net=inference_net
)

使用 Keras 风格的回调进行编译和训练:

amortizer.compile(optimizer="adam")
amortizer.train(processed_draws, epochs=30, batch_size=64)

2.5 后验估计

训练完成后,我们可以为任何新数据集推断后验样本:

test_data = adapter(simulator.sample(1))
posterior_samples = amortizer.sample(test_data["summary_variables"],conditions=test_data["inference_conditions"],n_samples=1000)

2.6 结果可视化

BayesFlow 包含方便的诊断工具来可视化结果:

bf.diagnostics.plots.pairs_samples(samples=posterior_samples,variable_names=[r"$\beta_0$", r"$\beta_1$", r"$\sigma$"]
)
http://www.dtcms.com/a/362035.html

相关文章:

  • C扩展4:X宏(X-MACRO)
  • JS循环机制
  • IS-IS的原理
  • Java超卖问题
  • MySQL安装与使用指南
  • 【读论文】量子关联增强双梳光谱技术
  • 力扣404 代码随想录Day15 第三题
  • 故障排查指南:理解与解决 “No route to host“ 错误
  • NOSQL——Redis
  • MySQL基础知识保姆级教程(四)视图与约束
  • 浅谈中断控制器:从 IRQ 到 IRR、IMR、In-Service Register
  • 软考-操作系统-错题收集(3)文件系统的索引节点结构
  • 【前端】《手把手带你入门前端》前端的一整套从开发到打包流程, 这篇文章都会教会你;什么是vue,Ajax,Nginx,前端三大件?
  • ComPE for win 纯净的PE系统
  • 软考中级数据库系统工程师学习专篇(67、数据库恢复)
  • Spring Security 深度学习(四): 会话管理与CSRF防护
  • 2025 数字化转型期,值得关注的 10 项高价值证书解析
  • Linux笔记---计算机网络概述
  • 视频动作识别模型-C3D
  • 线程池项目代码细节5(解决linux死锁问题)
  • 关系型数据库——GaussDB的简单学习
  • 《投资-43》- 自然=》生物=》人类社会=》商业=》金融=》股市=》投资的共同逻辑:生存竞争与进化论
  • 前端实现查询数据【导出】功能
  • 自制扫地机器人(二) Arduino 机器人避障设计——东方仙盟
  • A股大盘数据-20250901 分析
  • 设计模式:代理模式(Proxy Pattern)
  • HOW - 前端团队组长提升(沟通篇)
  • kubectl-etcd
  • RSA的CTF题目环境和做题复现第1集
  • nacos微服务介绍及环境搭建