unlike NumPy, jax.numpy doesn’t promote to float64 by default, for efficiency and accelerator (TPU) compatibility; it also supports bfloat16
use jax.config.update('jax_numpy_dtype_promotion', 'strict') to disable implicit type promotion (promotion from Python scalars is still allowed for convenience)
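A minimal sketch of both behaviours (array names are illustrative):

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_numpy_dtype_promotion', 'strict')

a = jnp.ones(3, dtype=jnp.float32)
b = jnp.ones(3, dtype=jnp.bfloat16)   # bfloat16 is supported as a first-class dtype
# a + b                               # strict mode: mixing concrete dtypes raises a promotion error
a + 1.0                               # Python scalars are weakly typed, so this still works (stays float32)
```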
wrap input args in tracer objects → record all operations applied to them → produce a jaxpr from the tracer records
| with jit | without jit |
|---|---|
| jaxpr operations are passed all at once as a graph / sequence | operations are dispatched one by one as Python interprets them |
jax.jit traces a function on its first call, wrapping array arguments in Traced<…> objects and tracking their shapes & dtypes (similar to graph guards in torch.compile), to produce a sequence of operations that can be passed to XLA for compilation
jax.make_jaxpr can be used to extract the produced sequence of operations
static operations are evaluated at compile-time in python
traced operations are evaluated at run-time in XLA
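A small sketch of what tracing produces (the function is illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    n = x.shape[0] * 2      # static: shapes are concrete Python values at trace time
    return jnp.sin(x) * n   # traced: recorded as primitives and executed by XLA

# prints a jaxpr with sin and mul primitives; the shape-derived constant is baked in
print(jax.make_jaxpr(f)(jnp.arange(3.0)))
```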
arguments used in Python control flow should be flagged as static using static_argnums=(static_arg_pos,)
or static_argnames=("arg_names",)
import jax
import jax.numpy as jnp
# or as a decorator: @partial(jax.jit, static_argnums=(1,))
def f(x, y):
    return x * 2 if y else x
x = jnp.arange(3.0)
double = True
jax.jit(f, static_argnums=(1,))(x, double)
the jax.jit decorator caches a compiled version of the function keyed on its hash; as long as the traced inputs keep the same shapes & dtypes the cached executable is reused, otherwise the function is retraced, recompiled and cached
decorate / wrap the function with jax.jit once and reuse the returned jitted function so the cached compilation is hit on every call
avoid producing a new hash on each call (e.g. by wrapping a fresh lambda or partial every time), which forces recompilation
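A minimal sketch of the caching pitfall (names are illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    return x + 1

for _ in range(3):
    g(jnp.ones(4))                          # traced/compiled once, cache hit afterwards
    # jax.jit(lambda x: x + 1)(jnp.ones(4)) # a fresh function object each iteration -> recompiles every time
```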
arguments passed to a jitted function can be “donated” (donate_argnums / donate_argnames), i.e. the device memory allocated for the input value is reused to store the output / updated value (after the function runs)
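A hedged sketch of buffer donation (the update rule is illustrative; donation only takes effect on backends that support it, e.g. GPU/TPU):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
    return params - 0.1 * grads   # the buffer backing `params` may be reused for the output

new_params = update(jnp.ones(1024), jnp.ones(1024))
# after the call, the donated input buffer must not be reused by the caller
```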
using jax.eval_shape we can trace a function on abstract inputs to figure out output shapes without running it (useful to avoid compiling twice when one of the arguments starts as None and later becomes an array)
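A small sketch using jax.ShapeDtypeStruct placeholders (function and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x, w):
    return x @ w

out = jax.eval_shape(f,
                     jax.ShapeDtypeStruct((2, 3), jnp.float32),
                     jax.ShapeDtypeStruct((3, 5), jnp.float32))
print(out.shape, out.dtype)   # (2, 5) float32 — traced abstractly, no FLOPs executed
```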
#jax.grad
jax.grad traces and transforms a given function f and returns a function that, given a set of inputs, differentiates f w.r.t. the first one
by default a jax.grad-transformed function computes the gradient w.r.t. its first argument (a jnp.array or PyTree) by tracing all operations inside that function (argnums selects other arguments)
jax.value_and_grad is similar to grad, but evaluates the function & its grad in one pass
y, dfdx = jax.value_and_grad(f)(x)
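A short sketch (the loss function is illustrative):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

w, x = jnp.ones(3), jnp.ones((4, 3))
dw = jax.grad(loss)(w, x)                      # gradient w.r.t. the first argument
l, dw2 = jax.value_and_grad(loss)(w, x)        # value and gradient in one pass
dw3, dx = jax.grad(loss, argnums=(0, 1))(w, x) # gradients w.r.t. both arguments
```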
#jax.jacobian
jax.jacobian is similar to jax.grad but operates on vector-valued functions
#jax.jvp
forward-mode auto-diff to compute the Jacobian-vector product
i.e. the directional derivative along an input vector v
v is a tangent vector in input space
y, jvp = jax.jvp(f, (x,), (v,))
$$ y = f(x), \quad \mathrm{jvp}(\mathbf{v}) = \mathrm{J}_f(x) \cdot \mathbf{v} $$

$$ \mathrm{J}_f(x) = \begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix} $$
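A concrete sketch of a JVP (f, x and v are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):                                    # R^2 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])                    # tangent vector in input space
y, t = jax.jvp(f, (x,), (v,))                # t == J_f(x) @ v, computed in forward mode
```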
#jax.vjp
reverse-mode vector-Jacobian product: computes $\mathbf{c}^\top \cdot \mathrm{J}_f(x)$, pulling a vector from output space back to input space
c is a cotangent vector in output space
y, pullback = jax.vjp(f, x)
$$ y = f(x), \quad \mathrm{vjp}(\mathbf{c}) = \mathbf{c}^\top \cdot \mathrm{J}_f(x) $$
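A matching sketch of the pullback (same illustrative f and x as in the JVP sketch, redefined so the block is self-contained):

```python
import jax
import jax.numpy as jnp

def f(x):                                    # R^2 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])
c = jnp.ones(2)                              # cotangent vector in output space
y, pullback = jax.vjp(f, x)
(vjp_x,) = pullback(c)                       # vjp_x == c^T @ J_f(x), computed in reverse mode
```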
#jax.jacobian
#jax.jacfwd
computes the full Jacobian matrix using forward-mode diff: efficient when the input dimension is small
#jax.jacrev
computes the full Jacobian matrix using reverse-mode diff: efficient when the output dimension is small
#jax.jacobian
a high-level wrapper for computing the full Jacobian (an alias of jax.jacrev)
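A quick comparison (illustrative function; both calls give the same (3, 2) Jacobian):

```python
import jax
import jax.numpy as jnp

def f(x):                                     # R^2 -> R^3
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])
Jf = jax.jacfwd(f)(x)   # built column by column from JVPs — cheap when inputs are few
Jr = jax.jacrev(f)(x)   # built row by row from VJPs — cheap when outputs are few
```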
#jax.linearize
evaluates a function f at a point / input vector x and returns the output together with a linear approximation of f around x (a jvp function)
useful for efficiency when the JVP has to be evaluated at many tangent vectors for the same point: the linearization is traced once
y, f_lin = jax.linearize(f, x)
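A sketch of reusing the linearization (f and the tangents are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.0)
y, f_lin = jax.linearize(f, x)               # trace & linearize f once at x
for v in (jnp.ones(3), jnp.array([1.0, 0.0, 0.0])):
    print(f_lin(v))                          # J_f(x) @ v without re-tracing f
```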
#jax.hessian
computes the Hessian matrix (forward-over-reverse: jacfwd of jacrev)
import jax
import jax.numpy as jnp
def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3
x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])
H = jax.jit(jax.hessian(f))(x)   # (2, 2) Hessian evaluated at x
H @ v                            # Hessian-vector product via the materialized H
alternatives:
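Presumably this refers to Hessian-vector products; a minimal forward-over-reverse sketch that avoids materializing H:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x[0] ** 2 + x[0] * x[1] + x[1] ** 3

def hvp(f, x, v):
    # differentiate grad(f) along the direction v: gives H(x) @ v directly
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])
hvp(f, x, v)   # same result as jax.hessian(f)(x) @ v
```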
vmap traces input arguments just like jax.jit and adds a batch dimension at axis 0 by default; the batch-dimension position can be configured per input and per output with in_axes & out_axes
(if set to None, the corresponding argument won’t be vectorized)
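For the snippet below, assume a simple 1-D convolution and inputs along these lines (hypothetical definitions, not from the original notes):

```python
import jax
import jax.numpy as jnp

def convolve1d(x, w):                 # x: (n,), w: (k,) -> (m,) with m = n - k + 1
    return jnp.convolve(x, w, mode='valid')

x = jnp.arange(5.0)                   # (n,)
w = jnp.array([1.0, 2.0, 3.0])        # (k,)
```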
xs = jnp.stack([x, x])            # (b, n): batched along axis 0
xst = xs.T                        # (n, b): batched along axis 1
wst = jnp.stack([w, w]).T         # (k, b): batched along axis 1
jax.vmap(convolve1d, in_axes=[1, None], out_axes=0)(xst, w)   # (b, m): w not vectorized
jax.vmap(convolve1d, in_axes=[1, 1], out_axes=1)(xst, wst)    # (m, b): both batched along axis 1
note: in_axes & out_axes can be PyTrees matching the structure of the corresponding inputs & outputs,
and a whole branch of a PyTree can be set to a single value (no need to spell out the full sub-structure with the same value at every leaf)
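A short sketch of PyTree-valued in_axes (the apply function and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def apply(p, x):
    return x @ p["w"] + p["b"]

batched_p = {"w": jnp.ones((8, 3, 4)), "b": jnp.ones((8, 4))}    # every leaf batched along axis 0
x = jnp.ones(3)
jax.vmap(apply, in_axes=({"w": 0, "b": 0}, None))(batched_p, x)  # axes spelled out per leaf
jax.vmap(apply, in_axes=(0, None))(batched_p, x)                 # 0 applied to the whole dict branch
```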
jax.vmap: vectorized map
jax.pmap: parallel map (SPMD: single program, multiple data, across devices)
Jaxpr: JAX exPRession (the torch IR equivalent)
a representation of a sequence of primitive operations, encoding the shape & dtype of each operation’s args, inputs and outputs
PyTrees: a functional freak’s solution for tossing state around
a tree is a collection of tensors such as a model’s params, an optimizer state, …
a leaf’s key path is a list of keys of length = depth of leaf
jax.tree_util.keystr() provides a reader-friendly key-path repr, using preset key types for built-in containers and __str__ for custom ones
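A small sketch of key paths (the tree is illustrative; the printed format is roughly what keystr produces):

```python
import jax
import jax.numpy as jnp

tree = {"layer1": {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}, "steps": [0, 1]}
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf)   # e.g. "['layer1']['b'] [0. 0.]"
```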
#jax.tree.structure & #jax.tree.leaves
skim a collection and extract its structure / leaves
PyTreeDef([*, *, {"key": *, "keys": {"key1": *, "key2": *}}])
[num1, num2, Array([…], dtype=…), Array([…], dtype=…), Array([…], dtype=…)]
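A small sketch producing outputs like the ones above (values are illustrative):

```python
import jax
import jax.numpy as jnp

params = [1, 2, {"key": jnp.zeros(3), "keys": {"key1": jnp.ones(2), "key2": jnp.ones(2)}}]
print(jax.tree.structure(params))                # PyTreeDef(...) — the container skeleton, leaves as *
print(jax.tree.leaves(params))                   # flat list of the 5 leaves
doubled = jax.tree.map(lambda x: x * 2, params)  # apply a function leaf-wise, keeping the structure
```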