当前位置: 首页 > news >正文

daily notes[56]

文章目录

  • the distribution function of the random vector
  • the density funtion of the random vector
  • to get the united density and the edge endsity from Python code with JAX
  • conditional probability density function
  • jax.numpy
    • jax.numpy.linspace
    • jax.numpy.arange
  • jax.typing
  • Diagonal Matrix
  • references

the distribution function of the random vector

在这里插入图片描述
the application of Monte Carlo integration ,which is much better than using trapezoid integration, with JAX for multidimensional integration approximately calculate integartion through random sampling.

∫abf(x)dx≈b−aN∑i=1Nf(xi),xi∼U(a,b)\int_a^b f(x) \, dx \approx \frac{b-a}{N} \sum_{i=1}^N f(x_i), \quad x_i \sim \mathcal{U}(a, b) abf(x)dxNbai=1Nf(xi),xiU(a,b)

import jax
import jax.numpy as jnp
from jax import randomdef monte_carlo_integrate(f, a, b, key, n_samples=10000):"""计算一维积分 ∫[a, b] f(x) dx"""x_samples = random.uniform(key, (n_samples,), minval=a, maxval=b)f_values = f(x_samples)return (b - a) * jnp.mean(f_values)# 示例:积分 sin(x) 在 [0, π]
key = random.PRNGKey(42)
f = lambda x: jnp.sin(x)
result = monte_carlo_integrate(f, 0, jnp.pi, key)
print("∫sin(x)dx ≈", result)  # 理论值: 2.0

multidimensional integration :
∫a1b1∫a2b2∫a3b3f(x1,x2,x3)dx1dx2dx3≈VN∑i=1Nf(xi)\int_{a_1}^{b_1} \int_{a_2}^{b_2} \int_{a_3}^{b_3} f(x_1, x_2, x_3) \, dx_1 dx_2 dx_3 \approx \frac{V}{N} \sum_{i=1}^N f(\mathbf{x}_i) a1b1a2b2a3b3f(x1,x2,x3)dx1dx2dx3NVi=1Nf(xi)
V=∏i=13(bi−ai)V = \prod_{i=1}^3 (b_i - a_i)V=i=13(biai)是积分区域体积

def monte_carlo_3d(f, a, b, key, n_samples=10000):"""计算三维积分 ∫[a1, b1]∫[a2, b2]∫[a3, b3] f(x1, x2, x3) dx1 dx2 dx3a, b: 各维度的积分限,如 a=[a1, a2, a3], b=[b1, b2, b3]"""x_samples = random.uniform(key, (n_samples, 3), minval=a, maxval=b)f_values = jax.vmap(f)(x_samples)  # 向量化计算volume = jnp.prod(jnp.array(b) - jnp.array(a))return volume * jnp.mean(f_values)# 示例:积分 x1 + x2 + x3 在 [0,1]×[0,1]×[0,1]
key = random.PRNGKey(42)
f = lambda x: x[0] + x[1] + x[2]
result = monte_carlo_3d(f, [0, 0, 0], [1, 1, 1], key)
print("∫∫∫(x1+x2+x3)dx1dx2dx3 ≈", result)  # 理论值: 1.5
from jax.scipy.stats import multivariate_normal# 定义三维高斯分布
mean = jnp.array([0.0, 1.0, -1.0])
cov = jnp.array([[1.0, 0.3, 0.1],[0.3, 1.0, 0.2],[0.1, 0.2, 1.0]
])# 联合PDF
joint_pdf = lambda x: multivariate_normal.pdf(x, mean=mean, cov=cov)# 蒙特卡洛计算CDF
def joint_cdf(a, b, c, key, n_samples=10000):# 在 [-5, a] x [-5, b] x [-5, c] 内均匀采样(近似无限积分)x_samples = random.uniform(key, (n_samples, 3), minval=jnp.array([-5.0, -5.0, -5.0]),maxval=jnp.array([a, b, c]))pdf_values = jax.vmap(joint_pdf)(x_samples)volume = (a - (-5)) * (b - (-5)) * (c - (-5))return volume * jnp.mean(pdf_values)key = random.PRNGKey(42)
print("P(X1≤0.5, X2≤1.5, X3≤-0.5) ≈", joint_cdf(0.5, 1.5, -0.5, key))
方法优点缺点
蒙特卡洛 (JAX)易并行化,支持高维,GPU加速结果有随机性
梯形积分确定性结果高维计算慢,内存占用大
SciPy quad/dblquad高精度不支持自动微分/GPU

the density funtion of the random vector

在这里插入图片描述
the density function is the partial derivative of X as follows:
在这里插入图片描述

to get the united density and the edge endsity from Python code with JAX

  • the united density
    we can explain the united density with an example which is three-dimensional Gaussian distrubution.
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal# 定义均值向量和协方差矩阵
mean = jnp.array([0.0, 1.0, -1.0])  # 3维均值
cov = jnp.array([[1.0, 0.3, 0.1],[0.3, 1.0, 0.2],[0.1, 0.2, 1.0]
])  # 3x3 对称正定协方差矩阵# 联合概率密度函数
def joint_pdf(x):return multivariate_normal.pdf(x, mean=mean, cov=cov)# 测试点
x_test = jnp.array([0.5, 1.5, -0.5])
print("联合概率密度:", joint_pdf(x_test))
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/learn1.py
联合概率密度: 0.051931586
# 检查对称性
is_symmetric = jnp.allclose(cov, cov.T)
print("协方差矩阵是否对称:", is_symmetric)# 检查正定性(通过Cholesky分解)
try:jnp.linalg.cholesky(cov)print("协方差矩阵正定")
except:print("协方差矩阵非正定!需修正")

在这里插入图片描述
it is the best way of calculating the integral with JAX to use trapezoid method because that JAX has inability to provide common API function such as SciPy's quad.
在这里插入图片描述

import jax.numpy as jnpdef trapz(y, x):"""手动实现梯形积分,兼容 JAX"""return jnp.sum((y[:-1] + y[1:]) / 2 * jnp.diff(x))# 示例:积分 sin(x) 在 [0, π]
x = jnp.linspace(0, jnp.pi, 1000)
y = jnp.sin(x)
result = trapz(y, x)
print(result)  # 应接近 2.0

you can also use jax.scipy.integrate.trapezoid.

import jax.numpy as jnpfrom jax.scipy.integrate import trapezoidx = jnp.linspace(0, jnp.pi, 1000)
y = jnp.sin(x)
result = trapezoid(y, x)
print(result)  # 输出 2.0

a complete example for computing the edge endsity be shown as follows:

from jax.scipy.integrate import trapezoid
import jax
import jax.numpy as jnpimport jax
from jax.scipy.stats import multivariate_normal# 定义三维高斯参数
mean = jnp.array([0.0, 1.0, -1.0])
cov = jnp.array([[1.0, 0.3, 0.1],[0.3, 1.0, 0.2],[0.1, 0.2, 1.0]
])# 联合概率密度函数
def joint_pdf(x):return multivariate_normal.pdf(x, mean=mean, cov=cov)# 计算边缘密度 p(x1) = ∫∫ p(x1, x2, x3) dx2 dx3
def marginal_pdf1(x1, n_samples=100):x2 = jnp.linspace(-5, 5, n_samples)  # 积分范围近似x3 = jnp.linspace(-5, 5, n_samples)X2, X3 = jnp.meshgrid(x2, x3)# 向量化计算积分integrand = lambda x2, x3: joint_pdf(jnp.array([x1, x2, x3]))y = jax.vmap(jax.vmap(integrand))(X2, X3)# 双重梯形积分dx2 = x2[1] - x2[0]dx3 = x3[1] - x3[0]return trapezoid(trapezoid(y, dx=dx2), dx=dx3)# 计算边缘密度 p(x2) = ∫∫ p(x1, x2, x3) dx1 dx3
def marginal_pdf2(x2, n_samples=100):x1 = jnp.linspace(-5, 5, n_samples)  # 积分范围近似x3 = jnp.linspace(-5, 5, n_samples)X1, X3 = jnp.meshgrid(x1, x3)# 向量化计算积分integrand = lambda x1, x3: joint_pdf(jnp.array([x1, x2, x3]))y = jax.vmap(jax.vmap(integrand))(X1, X3)# 双重梯形积分dx1 = x1[1] - x1[0]dx3 = x3[1] - x3[0]return trapezoid(trapezoid(y, dx=dx1), dx=dx3)# 测试
print("边缘密度 p(x2=0.6):", marginal_pdf2(0.6))
print("边缘密度 p(x1=0.3):", marginal_pdf1(0.3))
边缘密度 p(x2=0.6): 0.36825997
边缘密度 p(x1=0.3): 0.38137165

conditional probability density function

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述now,we try to solve the problem that how to calculate the conditional probability density of three dimensions random vectors which conform to joint gauss distribution .
在这里插入图片描述

the joint gauss distribution was defined as following Python codes:

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal# 定义均值和协方差矩阵 (3维)
mu = jnp.array([1.0, -1.0, 0.5])  # [μ_X, μ_Y, μ_W]# 协方差矩阵 (3x3)
Sigma = jnp.array([[2.0, 0.5, -0.3],   # X的方差和协方差[0.5, 1.0, 0.2],     # Y的方差和协方差[-0.3, 0.2, 1.5]     # W的方差和协方差
])# 检查协方差矩阵是否正定
assert jnp.all(jnp.linalg.eigvals(Sigma) > 0), "协方差矩阵必须正定!"

the conditional probability distribution of X when Y=y and W=w is as follows:

# 分块矩阵索引
# Z = [X, Y, W], 条件分布 P(X | Y=y, W=w)
# 分割协方差矩阵
Sigma_XX = Sigma[0, 0].reshape(1, 1)          # Var(X)
Sigma_XY = Sigma[0, 1:].reshape(1, 2)          # Cov(X, [Y, W])
Sigma_YY = Sigma[1:, 1:]                       # Var([Y, W])
Sigma_YX = Sigma_XY.T# 给定观测值
y_observed = jnp.array([-0.5, 1.0])  # [y, w]# 计算条件均值和协方差
Sigma_YY_inv = jnp.linalg.inv(Sigma_YY)
mu_X_given_YW = mu[0] + Sigma_XY @ Sigma_YY_inv @ (y_observed - mu[1:])
Sigma_X_given_YW = Sigma_XX - Sigma_XY @ Sigma_YY_inv @ Sigma_YXprint("Conditional mean (X | Y, W):", mu_X_given_YW)
print("Conditional variance (X | Y, W):", Sigma_X_given_YW)

jax.numpy

jax.numpy.linspace

jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)

return a set of numbers separated with the fixed interval evenly.

import jax
import jax.numpy as jnp
print(jnp.linspace(11,55,5))
[11. 22. 33. 44. 55.]
import jax
import jax.numpy as jnp
print(jnp.linspace(11,55,5,retstep=True))
(Array([11., 22., 33., 44., 55.], dtype=float32), Array(11., dtype=float32))

the function can also use for generating multiple dimensions array.

import jax
import jax.numpy as jnp
print(jnp.linspace(jnp.array([1,11]),jnp.array([5,55]),5))
[[ 1. 11.][ 2. 22.][ 3. 33.][ 4. 44.][ 5. 55.]]

jax.numpy.arange

the function make also a sequence consisted of number and they are separated by equal interval.the function is similar as jax.numpy.linspace.but the important difference is that jax.numpy.arange runing depend on the step which means interval and the jax.numpy.linspace is applied with the argument num which represents the number of numbers.

jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)
[11 22 33 44]

there are a difference which is easily ignored usually that the stop at jax.numpy.arange means the end but excusive itself.by the way, jax.numpy.arange’s arguments must be scalars.
the following code will get a error when it runing.

print(jnp.arange(jnp.array([1,11]),jnp.array([5,55]),11))

jax.typing

  1. the function annotations applied for static type checking maybe become a integral python coding standard.
  2. jax.Array is the base class represented array.
  3. to annotate in python project.
  • Level 1: Annotations as documentation
def f(x: jax.Array) -> jax.Array:  # type annotations are valid for traced and non-traced types.return x
  • Level 2: Annotations for intelligent autocomplete
    the many modern IDEs such as vscode make use of the type annotations in intelligent code completion systems.
  • Level 3: Annotations for static type-checking
  1. the package development with JAX must abide by two python type checking facilities including pytype developed by google , and mypy which known as the most popular static type checking tools.And beyond that, JAX will face chanllenges such as array duck-typing,transformations and decorators,array annotation lack of granularity and imprecise APIs inherited from NumPy.
  2. JAX provided that static type annotations and runtime instance checks for duck-typed objects.
  • Static type annotations
from typing import Union
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnpArrayAnnotation = Union[Array, Tracer]@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:assert isinstance(x, (Array, Tracer))  # Explicit checkreturn x * 2x = jnp.array([1.0, 2.0, 3.0])
result = f(x)
print(result)  # [2. 4. 6.] (jax.Array)@jit
def g(x):return f(x)  # `x` is a Tracer here!print(g(x))      # Same output, but internally tracedfrom jax import graddf_dx = grad(lambda x: f(x).sum())  # Works with tracers
print(df_dx(x))  # [2. 2. 2.] (gradient of x*2)f("invalid_input")
f(234)
[2. 4. 6.]
[2. 4. 6.]
[2. 2. 2.]
Traceback (most recent call last):File "e:\learn\learnpy\l2.py", line 29, in <module>f("invalid_input")
TypeError: Error interpreting argument to <function f at 0x0000018EEFD999E0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
  • Runtime instance checks
from typing import Union
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnpArrayInstance = Union[Array, Tracer]@jit
def f(x):return isinstance(x, ArrayInstance)x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

Diagonal Matrix

  1. one dimension array which its elements be placed in the diagonal matrix can be convert to the diagonal matrix through jnp.diag.
import jax.numpy as jnp# 从一维数组创建对角矩阵
vec = jnp.array([1, 2, 3])
diag_matrix = jnp.diag(vec)  # 对角线为 [1, 2, 3],其余为 0
print(diag_matrix)

of course,you can change reversely such as grabing the elements from a diagonal matrix.

matrix = jnp.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
extracted_vec = jnp.diag(matrix)  # 返回 [1, 2, 3]

defining a unit matrix can apply jnp.eye.

import jax.numpy as jnp# 从一维数组创建对角矩阵
vec_I = jnp.eye(3)
print(vec_I)

now ,we write a bit of python code for illustrating AI=IA=AAI=IA=AAI=IA=A.

import jax
import jax.numpy as jnpkey = jax.random.PRNGKey(42)  # 随机种子
A = jax.random.uniform(key, shape=(5, 3), minval=0.0, maxval=10.0)
print(A)
vec_I = jnp.eye(3)
print(vec_I)
print(A @ vec_I)
jnp.dot(A, vec_I)
[[4.8870955 6.7979717 6.162715 ][5.610161  4.506446  5.858659 ][0.7480943 7.7513337 6.9895926][8.1863365 3.503052  8.7282   ][9.258814  8.601307  4.775541 ]]
[[1. 0. 0.][0. 1. 0.][0. 0. 1.]]
[[4.8870955 6.7979717 6.162715 ][5.610161  4.506446  5.858659 ][0.7480943 7.7513337 6.9895926][8.1863365 3.503052  8.7282   ][9.258814  8.601307  4.775541 ]]

to create a random matrix is as follows:

key = jax.random.PRNGKey(0)
random_matrix = jax.random.normal(key, (3, 3))

you can generate a matrix from list ,for example

matrix = jnp.array([[1, 2], [3, 4]])

references

  1. https://docs.jax.dev/
  2. deepseek
http://www.dtcms.com/a/413075.html

相关文章:

  • 盐城做网站找哪家好建设好网站能赚到钱吗
  • 有哪些网站可以做青旅义工wordpress加目录
  • 企业网站模板演示德州手机网站建设报价
  • 鄱阳县精准扶贫旅游网站建设目的用ps个人网站怎么做
  • 做网站需要公司么业之峰家装公司地址
  • 快速收录网站微信开发者工具有什么作用
  • 查钓鱼网站seo优化排名百度教程
  • 虹口专业做网站网站建设预算描述
  • 模板网站的坏处铁岭做网站的公司
  • 水果电商网站开发方案做效果图的网站有哪些软件
  • 做游戏网站打鱼矿泉水瓶50种手工制作
  • 网站建设销售培训网站建设 超薄网络
  • 网站上怎么做动画广告视频玉环建设规划局网站
  • 校园二手市场网站建设最专业的房地产网站建设
  • 网站制作内联框unity 做网站
  • 建立网站顺序网站建设公司宣传语
  • 35互联网站建设甘肃省住房和城乡建设厅注册中心网站
  • 徐州哪家做网站好gif表情包制作网站
  • 标志空间网站做网站赚钱一般做什么
  • php网站开发要学什么软件呼叫中心系统平台
  • 网页设计作业动漫网页英文外链seo兼职
  • 网站建设方案书 备案wordpress+php要求
  • 【Nginx开荒攻略】Nginx静态文件服务:从MIME类型到缓存优化的完整指南
  • 上传的网站怎么打开洛可可设计公司怎么样
  • 秦皇岛城乡住房建设厅网站申请域名免费
  • 教育兼职网站开发青岛北京网站建设价格
  • 从机械齿轮到硅基大脑:计算机起源探秘(2)
  • 个人网站设计模板素材企业网络构建
  • 团购网站怎么推广专业网站建设课程
  • wordpress本站导航在哪里做cpa网站