JAX study notes[17]
文章目录
- conditional expectation
- references
conditional expectation
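The Monte Carlo method is a popular technique for numerical computation based on random sampling and statistical simulation. It relies on the law of large numbers: the average of many independent samples converges to the expected value. The examples below draw exponential samples by inverse transform sampling: if $U \sim \mathrm{Uniform}(0, 1)$, then $X = -\ln(1 - U)/\lambda \sim \mathrm{Exp}(\lambda)$. They then estimate the conditional expectation $E[X \mid Y = y]$, whose exact value for $X \mid Y \sim \mathrm{Exp}(Y)$ is $1/y$.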
import jax
import jax.numpy as jnp
from jax import random


def exp_inverse_sample(key, lam, n_samples):
    """Draw exponential samples by inverse transform sampling.

    If U ~ Uniform[0, 1), then X = -ln(1 - U) / lam follows Exp(lam)
    (rate parameterization, so the mean is 1 / lam).

    Args:
        key: JAX PRNG key used to draw the uniform variates.
        lam: Rate parameter of the exponential distribution; expected to be
            positive (not checked, so the function stays traceable).
        n_samples: Number of samples to draw.

    Returns:
        A 1-D array of shape (n_samples,) holding the samples.
    """
    u = random.uniform(key, (n_samples,))
    # log1p(-u) equals log(1 - u) but keeps full precision when u is near 0.
    return -jnp.log1p(-u) / lam


key = random.PRNGKey(42)
samples = exp_inverse_sample(key, lam=0.5, n_samples=10_000)
print("前5个样本:", samples[:5])
from jax import random
import jax
import jax.numpy as jnp

# ---- Parameters ----
# Rate of the marginal Y ~ Exp(0.5). The Monte Carlo estimate that follows
# conditions on a fixed observed Y, so this value is not referenced there.
lambda_Y: float = 0.5
# Observed value of Y that the conditional expectation is computed for.
y_observed: float = 2.0
# Number of Monte Carlo draws from X | Y ~ Exp(Y).
n_samples: int = 10000
def conditional_sample_X(key, Y):
    """Draw one sample of X from the conditional distribution X | Y ~ Exp(Y).

    Uses inverse transform sampling: with U ~ Uniform[0, 1),
    X = -ln(1 - U) / Y is exponentially distributed with rate Y.

    Args:
        key: JAX PRNG key; a single key yields a single scalar draw.
        Y: Rate of the conditional distribution (the value of Y being
            conditioned on); expected to be positive.

    Returns:
        A scalar array holding one sample of X.
    """
    u = random.uniform(key)  # one scalar uniform draw per key
    # log1p(-u) equals log(1 - u) with better precision for u near 0.
    return -jnp.log1p(-u) / Y
key = random.PRNGKey(42)
keys = random.split(key, n_samples)  # one key per sample, shape (n_samples, 2)

# Map over the keys only. The observed Y is broadcast via in_axes=None rather
# than materializing an n_samples-long array filled with the same value.
X_samples = jax.vmap(conditional_sample_X, in_axes=(0, None))(keys, y_observed)

# Monte Carlo estimate of E[X | Y = y_observed]: the sample mean of the draws.
cond_expect = jnp.mean(X_samples)
print(f"E[X|Y={y_observed}] 蒙特卡洛估计: {cond_expect:.4f}")
# For X | Y = y ~ Exp(y), the exact conditional mean is 1 / y.
print(f"理论值: {1/y_observed:.4f}")
references
- deepseek