daily notes[55]
Table of contents
- the matrix manipulation
- SPMD
- references
the matrix manipulation
The following two matrices are used to demonstrate each operation.
import jax.numpy as jnp
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
- matrix addition
total = a + b  # named total to avoid shadowing the built-in sum()
- element-wise multiplication (Hadamard product)
elementwise = a * b
- matrix product (dot product)
dot = jnp.dot(a, b)
dot = a @ b  # equivalent, using the @ operator
- matrix transposition
transpose = a.T
- inverse matrix
inv = jnp.linalg.inv(a)
- determinant
det = jnp.linalg.det(a)
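The operations above can be checked together in one short script. This is only a minimal sketch, assuming JAX is installed. The names `total`, `product` and `identity_check` are illustrative and not part of the original notes, and the entries are written as floats so that the inverse and determinant keep a floating-point dtype throughout.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])

total = a + b            # element-wise sum
elementwise = a * b      # element-wise (Hadamard) product
product = a @ b          # matrix product, same as jnp.dot(a, b)
transpose = a.T          # rows and columns swapped
inv = jnp.linalg.inv(a)  # requires a square, non-singular matrix
det = jnp.linalg.det(a)

# A matrix times its inverse should give the identity matrix,
# up to floating-point rounding.
identity_check = jnp.allclose(a @ inv, jnp.eye(2), atol=1e-4)

print(total)
print(elementwise)
print(product)
print(det, identity_check)
```

For these inputs the matrix product is [[19, 22], [43, 50]] and the determinant is -2 (up to rounding), whereas the element-wise product is [[5, 12], [21, 32]], which shows that `a * b` and `a @ b` are different operations.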
SPMD
- JAX supports Single-Program Multi-Data (SPMD) parallelism.
- SPMD means that the same computation runs in parallel on different pieces of the data. JAX does not require all of the data to sit on the same device.
- Data can be placed on different devices, such as CPUs, GPUs and TPUs, for computation.
- jax.jit
jax.jit() speeds up a JAX Python function by applying Just In Time (JIT) compilation to it.
jax.jit() relies on one basic principle: the function is traced and decomposed into a sequence of primitive operations, each of which represents a fundamental unit of computation.
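To make the idea of primitive operations concrete, the sketch below compiles a small function with `jax.jit` and prints the traced sequence of primitives using `jax.make_jaxpr`. The function `scaled_sum` and its input are hypothetical examples chosen for illustration, not taken from the original notes.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x):
    # Hypothetical example: a few element-wise ops followed by a reduction.
    return jnp.sum(jnp.tanh(x) * 2.0 + 1.0)

# jax.jit returns a wrapped function that is compiled on its first call;
# later calls with the same input shape and dtype reuse the compiled code.
fast_scaled_sum = jax.jit(scaled_sum)

x = jnp.arange(4.0)
eager_result = scaled_sum(x)     # runs operation by operation
jit_result = fast_scaled_sum(x)  # runs the compiled version

# jax.make_jaxpr shows the primitive operations that JAX traces
# out of the Python function.
jaxpr = jax.make_jaxpr(scaled_sum)(x)
print(jaxpr)
print(eager_result, jit_result)
```

Both calls should return the same value. The printed jaxpr lists one line per primitive (for example the `tanh` and the multiplication), which is the decomposed form that gets compiled.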
jax.jit(fun, /, *, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)
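As a rough illustration of how SPMD and jax.jit fit together, the sketch below splits the rows of an array across all available devices and runs one jitted function on it. It assumes the `jax.sharding` API (`Mesh`, `NamedSharding`, `PartitionSpec`) of recent JAX releases. The axis name `"data"` and the function `normalize_rows` are made up for this example. Here the sharding is attached to the input with `jax.device_put` rather than passed through `in_shardings`, and the code also runs on a machine with a single device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())  # all devices visible to JAX
mesh = Mesh(devices, ("data",))    # 1-D device mesh with one named axis

# Make the row count a multiple of the device count so rows split evenly.
n_rows = 2 * devices.size
x = jnp.arange(n_rows * 4, dtype=jnp.float32).reshape(n_rows, 4)

# Split rows along the "data" mesh axis; columns are not split.
row_sharding = NamedSharding(mesh, P("data", None))
x_sharded = jax.device_put(x, row_sharding)

@jax.jit
def normalize_rows(m):
    # The same program is applied to every shard of the data.
    return m / (1.0 + jnp.sum(m, axis=1, keepdims=True))

y = normalize_rows(x_sharded)
print(y.sharding)  # where the result lives
print(y.shape)
```

The result has the same shape as the input, and `y.sharding` reports how the output is laid out across devices.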
references
- DeepSeek
- https://docs.jax.dev/