JAX study notes [19]
Contents
- simple iteration method
- references
simple iteration method
- The simple iteration method (fixed-point iteration) computes a root of an equation by repeatedly refining an estimate so that it approaches the true root more and more closely.
- Fixed-point iteration looks for a point $x'$ that satisfies $f(x') = x'$. For example, the equation $x^3 + 2x^2 - 9x + 5 = 0$ is equivalent to $x = \frac{x^3 + 2x^2 + 5}{9}$, so here $f(x) = \frac{x^3 + 2x^2 + 5}{9}$.
- First, an initial guess $x_1'$ is chosen. Next, $x_1'$ is substituted into the function $f(x)$ to get $f(x_1')$, which becomes the new value $x_2'$. This is repeated until, at the $n$-th step, $x_n'$ is approximately equal to $f(x_n')$.
- The basic code can be as follows.
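import jax
import jax.numpy as jnp
from jax import jit


def fixed_point_iteration(g, x0, tol=1e-6, max_iter=1000):
    """Fixed-point iteration for solving g(x) = x.

    Args:
        g: iteration function
        x0: initial guess
        tol: tolerance (default 1e-6)
        max_iter: maximum number of iterations (default 1000)

    Returns:
        The approximate fixed point.
    """

    def cond_fn(val):
        x_prev, x, i = val
        # Keep going while successive iterates still differ and the budget is not used up
        return (jnp.abs(x - x_prev) > tol) & (i < max_iter)

    def body_fn(val):
        x_prev, x, i = val
        return x, g(x), i + 1

    # Iterate with while_loop
    _, result, _ = jax.lax.while_loop(cond_fn, body_fn, (x0, g(x0), 0))
    return result


# Speed up with JIT compilation (g is a Python function, so it is marked static)
fixed_point_iteration_jit = jit(fixed_point_iteration, static_argnums=(0,))

# Example usage
if __name__ == "__main__":
    # Define the iteration function (e.g. solve x = cos(x))
    def g(x):
        return jnp.cos(x)

    # Call the fixed-point iteration
    x_star = fixed_point_iteration_jit(g, 1.0)
    print(f"Approximate fixed point: {x_star}")
    print(f"Check g(x) - x = {g(x_star) - x_star}")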
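To tie the loop back to the polynomial example from the start of this section, the sketch below applies the rearranged form $x = \frac{x^3 + 2x^2 + 5}{9}$ a fixed number of times and keeps every iterate, so the approach toward the root can be inspected. The helper name `iterate_history`, the starting value 0.5, and the step count 30 are assumptions chosen for illustration; `jax.lax.scan` is used here instead of `while_loop` only because it records the intermediate values.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Rearranged form of x^3 + 2x^2 - 9x + 5 = 0, i.e. x = (x^3 + 2x^2 + 5) / 9
    return (x**3 + 2.0 * x**2 + 5.0) / 9.0


def iterate_history(g, x0, num_steps):
    """Apply x <- g(x) num_steps times and return every iterate (hypothetical helper)."""

    def step(x, _):
        x_next = g(x)
        # carry the new value forward and also emit it as output
        return x_next, x_next

    init = jnp.asarray(x0, dtype=jnp.float32)
    _, history = jax.lax.scan(step, init, None, length=num_steps)
    return history


history = iterate_history(g, 0.5, 30)
x_last = history[-1]

# Residual of the original equation at the last iterate
residual = x_last**3 + 2.0 * x_last**2 - 9.0 * x_last + 5.0

print("first iterates:", history[:5])
print("last iterate:", x_last, "residual:", residual)
```

With this starting value the iterates should settle near the root that lies between 0.70 and 0.71; other starting values may behave differently.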
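The basic `fixed_point_iteration` above ignores the case where the iteration does not converge. The following version also reports whether the iteration converged and how many iterations it used.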
import jax
import jax.numpy as jnp
from jax import jit


def fixed_point_iteration_advanced(g, x0, tol=1e-6, max_iter=1000):
    """Fixed-point iteration with convergence diagnostics.

    Returns:
        (approximate fixed point, whether it converged, number of iterations)
    """

    def cond_fn(val):
        x_prev, x, i, converged = val
        # Stop as soon as the iteration has converged or the budget is used up
        return (~converged) & (i < max_iter)

    def body_fn(val):
        x_prev, x, i, _ = val
        x_new = g(x)
        converged = jnp.abs(x_new - x) < tol
        return x, x_new, i + 1, converged

    # Iterate with while_loop
    _, result, iterations, converged = jax.lax.while_loop(
        cond_fn, body_fn, (x0, g(x0), 0, False)
    )
    return result, converged, iterations


# Speed up with JIT compilation (g is a Python function, so it is marked static)
fixed_point_iteration_advanced_jit = jit(
    fixed_point_iteration_advanced, static_argnums=(0,)
)

# Example usage
if __name__ == "__main__":
    # Define the iteration function (e.g. solve x = e^{-x})
    def g(x):
        return jnp.exp(-x)

    # Call the advanced version
    x_star, converged, iters = fixed_point_iteration_advanced_jit(g, 0.5)
    print(f"Approximate fixed point: {x_star}")
    print(f"Converged: {converged}")
    print(f"Iterations: {iters}")
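One way to see why an iteration may fail to converge: a standard sufficient condition for local convergence is $|g'(x^*)| < 1$ at the fixed point $x^*$. The sketch below is an illustration, not part of the original notes. It uses `jax.grad` to evaluate that slope at each root of the polynomial example $x^3 + 2x^2 - 9x + 5 = 0$ with $g(x) = \frac{x^3 + 2x^2 + 5}{9}$. The reference roots come from `numpy.roots`, and the labels "attracting" and "repelling" are just descriptive names chosen here.

```python
import jax
import numpy as np


def g(x):
    # Iteration function from the polynomial example: x = (x^3 + 2x^2 + 5) / 9
    return (x**3 + 2.0 * x**2 + 5.0) / 9.0


# Derivative of the iteration function via automatic differentiation
dg = jax.grad(g)

# Roots of x^3 + 2x^2 - 9x + 5 = 0, used only as reference points (assumption:
# all three roots are real, so the real part is the root itself)
roots = np.sort(np.real(np.roots([1.0, 2.0, -9.0, 5.0])))

# |g'(root)| < 1 suggests the iteration is locally attracted to that root
slopes = [abs(float(dg(float(r)))) for r in roots]

for r, s in zip(roots, slopes):
    label = "attracting" if s < 1.0 else "repelling"
    print(f"root {r: .4f}: |g'(root)| = {s:.3f} ({label})")
```

Under this rearrangement only the middle root satisfies the condition, which is consistent with the iteration settling there for nearby starting values; a different rearrangement of the same equation would be needed to reach the other roots.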
references
- DeepSeek
- 《数值计算方法》 (Numerical Computation Methods)