Array

Attributes

Type promotion

unlike NumPy, jax.numpy doesn't promote to float64 (64-bit types are disabled by default), for efficiency and for compatibility with accelerators such as TPUs; it also supports bfloat16
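A minimal sketch of these defaults, assuming a standard JAX install with `jax_enable_x64` left off. It shows the float32 default, a weakly typed Python scalar not widening a bfloat16 array, and a mixed int/float case where JAX stays narrower than NumPy.

```python
import jax.numpy as jnp

# Python floats become float32 (64-bit types are off unless jax_enable_x64 is set)
x = jnp.array([1.0, 2.0])
print(x.dtype)            # float32

# A Python scalar is weakly typed and does not widen a bfloat16 array
b = jnp.ones(3, dtype=jnp.bfloat16)
print((b + 1.0).dtype)    # bfloat16

# int32 + float16 stays float16 in JAX (NumPy's rules would give float64)
i = jnp.ones(3, dtype=jnp.int32)
h = jnp.ones(3, dtype=jnp.float16)
print((i + h).dtype)      # float16
```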

use jax.config.update('jax_numpy_dtype_promotion', 'strict') to disable implicit type promotion between arrays (weakly typed Python scalars are still promoted, for convenience)
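A small sketch of strict mode, assuming the error raised on mixed array dtypes is JAX's `TypePromotionError` (caught generically here). The flag is restored to `'standard'` at the end so it does not leak into later code.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_numpy_dtype_promotion', 'strict')

a = jnp.ones(2, dtype=jnp.float32)
scalar_ok = (a + 2.0).dtype   # float32: Python scalars are still accepted

try:
    a + jnp.ones(2, dtype=jnp.bfloat16)   # mixed array dtypes: rejected in strict mode
    raised = False
except Exception:                          # expected: a TypePromotionError
    raised = True

jax.config.update('jax_numpy_dtype_promotion', 'standard')  # restore the default
```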

Transformations

Jit

on call: wrap the input args in tracers → record all operations

on return → produce a jaxpr from the tracer records

with jit: the jaxpr operations are passed all at once, as a graph
without jit: operations are dispatched one by one as the Python code is interpreted

jax.jit traces the function during its first call: arguments are wrapped in Traced<…> objects that keep track of shapes and dtypes (similar to graph guards in torch.compile), producing a sequence of operations that XLA compiles and executes

jax.make_jaxpr can be used to extract the produced sequence of operations
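A minimal sketch of inspecting a trace; the function `f` here is an arbitrary example. `jax.make_jaxpr(f)` returns a function that, given example inputs, returns the recorded program.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

# Trace f with example inputs and return the recorded program (a ClosedJaxpr)
jaxpr = jax.make_jaxpr(f)(jnp.ones(3), 2.0)
print(jaxpr)   # lists the primitives (sin, mul, add) with their shapes and dtypes
```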

static args

static operations are evaluated in Python at compile (trace) time

traced operations are evaluated at run time by XLA

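arguments used in Python control flow must be flagged as static (their values must be hashable), by position with static_argnums=(static_arg_pos, )

or by name with static_argnames=("arg_names",); each new value of a static argument triggers a recompilation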

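import jax
import jax.numpy as jnp

# decorator form: @functools.partial(jax.jit, static_argnums=(1,))
def f(x, y):
    return x*2 if y else x   # Python `if` on y: y must be static

x = jnp.arange(3.)
double = True
jax.jit(f, static_argnums=(1,))(x, double)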
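A sketch of what happens when the flag is missing, and of the by-name alternative. It assumes a JAX version where branching on a tracer raises a subclass of `jax.errors.ConcretizationTypeError` (recent versions raise `TracerBoolConversionError`, which is one).

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x * 2 if y else x

x = jnp.arange(3.0)

# Without marking y static, `if y` needs a concrete bool but receives a tracer
try:
    jax.jit(f)(x, True)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

# static_argnames marks static arguments by name instead of position
g = jax.jit(f, static_argnames=("y",))
out = g(x, y=True)   # [0., 2., 4.]
```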

compiled func caching

jax.jit caches the compiled version of a function, keyed on the function's hash and on the shapes and dtypes of the traced inputs (plus the values of static args): as long as these match, the cached version is reused; otherwise the function is retraced, recompiled and cached again
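A minimal sketch that counts traces with a Python side effect (a hypothetical `trace_count` global), which only runs while JAX is tracing, not on cached calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1          # only executed while tracing
    return jnp.sin(x) * 2

f(jnp.ones(3))                     # traced and compiled
f(jnp.zeros(3))                    # same shape and dtype: cache hit
f(jnp.ones(4))                     # new shape: retraced
f(jnp.ones(3, dtype=jnp.int32))    # new dtype: retraced
print(trace_count)                 # 3
```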

when an already-jitted function is wrapped again with jax.jit, the cached version of the inner function is used

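avoid creating a new function object (and thus a new hash) on each call, e.g. jax.jit(lambda ...) or a functools.partial built inside a loop: every new object misses the cache and triggers a recompilation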
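A sketch contrasting the anti-pattern with the fix, again counting traces through a side effect in a hypothetical helper `g`. A fresh lambda per iteration is a new object, so each call retraces; jitting once outside the loop traces only once.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def g(x):
    global trace_count
    trace_count += 1          # runs only when g is being traced
    return x + 1

x = jnp.ones(3)

# Anti-pattern: a new lambda (new hash) each iteration, so each call retraces
for _ in range(3):
    jax.jit(lambda v: g(v))(x)
bad = trace_count             # 3

trace_count = 0
g_jit = jax.jit(g)            # create the jitted function once
for _ in range(3):
    g_jit(x)
good = trace_count            # 1
```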

implicit, unwanted behaviors

notes

Auto-diff

#jax.grad

jax.grad traces and transforms a scalar-valued function f and returns a new function that, given the same inputs, computes the gradient of f with respect to its first argument

the differentiated argument can be a jnp.array or a PyTree of arrays (the gradient has the same structure); a different argument can be selected with argnums, e.g. jax.grad(f, argnums=1)
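A minimal sketch with a hypothetical linear-model `loss` whose first argument is a dict PyTree. `jax.grad` returns a dict with the same keys; `argnums=1` differentiates with respect to `x` instead.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # hypothetical linear model; params is a PyTree (dict) of arrays
    return jnp.sum((params["w"] * x + params["b"]) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
x = jnp.array([1.0, 2.0])

grads = jax.grad(loss)(params, x)             # {"b": 16., "w": 26.}
dx = jax.grad(loss, argnums=1)(params, x)     # gradient w.r.t. x: [12., 20.]
```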

jax.value_and_grad is similar to jax.grad, but returns both the function's value and its gradient

y, dfdx = jax.value_and_grad(f)(x)
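A minimal usage sketch, assuming f(x) = sum(x²), whose gradient is 2x:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)   # scalar output, as jax.grad requires

x = jnp.array([1.0, 2.0, 3.0])
y, dfdx = jax.value_and_grad(f)(x)   # y = 14.0, dfdx = [2., 4., 6.]
```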

#jax.jacobian

jax.jacobian is similar to jax.grad but operates on vector-valued functions, returning the full Jacobian matrix

#jax.jvp

forward-mode auto-diff to compute the Jacobian-vector product (JVP)

i.e. the directional derivative of f along an input vector v

v is a tangent vector in input space

y, jvp = jax.jvp(f, (x,), (v,))

$$ y = f(x), \quad \mathrm{jvp}(\mathbf{v}) = \mathrm{J}_f(x)\,\mathbf{v} $$

$$ \mathrm{J}_f(x) = \begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix} $$
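A minimal sketch with a small example function f: R² → R² (chosen for illustration). Pushing the tangent v = e₀ through `jax.jvp` yields the first column of the Jacobian, checked against `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # J_f(x) = [[x[1], x[0]], [cos(x[0]), 0]]
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([2.0, 3.0])
v = jnp.array([1.0, 0.0])            # tangent: direction e_0 in input space

y, jvp = jax.jvp(f, (x,), (v,))      # J_f(x) v = first column of J: [3., cos(2)]
J = jax.jacfwd(f)(x)                 # full Jacobian, for comparison
print(jnp.allclose(jvp, J @ v))      # True
```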

#jax.vjp

reverse-mode auto-diff to compute the vector-Jacobian product (VJP) $\mathbf{c}^\top \mathrm{J}_f(x)$

c is a cotangent vector in output space

y, pullback = jax.vjp(f, x)

$$ y = f(x), \quad \mathrm{vjp}(\mathbf{c}) = \mathbf{c}^\top \mathrm{J}_f(x) $$
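A minimal sketch using the same kind of example function. The pullback returns a tuple with one entry per primal input, and the cotangent c = e₀ selects the first row of the Jacobian.

```python
import jax
import jax.numpy as jnp

def f(x):
    # J_f(x) = [[x[1], x[0]], [cos(x[0]), 0]]
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([2.0, 3.0])
y, pullback = jax.vjp(f, x)

c = jnp.array([1.0, 0.0])        # cotangent: selects the first output
(vjp_c,) = pullback(c)           # c^T J_f(x) = first row of J: [3., 2.]
```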

#jax.jacobian

#jax.jacfwd

compute the full Jacobian matrix with forward-mode diff (one JVP per input dimension): efficient when the input dim is small

#jax.jacrev

compute the full Jacobian matrix with reverse-mode diff (one VJP per output dimension): efficient when the output dim is small

#jax.jacobian

an alias of jax.jacrev: it does not choose the mode from the input/output shapes, so call jax.jacfwd explicitly when the input dim is small
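A minimal sketch comparing the three entry points on an illustrative f: R³ → R². All return the same 2×3 matrix; they differ only in how it is computed.

```python
import jax
import jax.numpy as jnp

def f(x):
    # R^3 -> R^2; J = [[x[1], x[0], 0], [0, 0, cos(x[2])]]
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
J_fwd = jax.jacfwd(f)(x)    # forward mode: one JVP per input (3)
J_rev = jax.jacrev(f)(x)    # reverse mode: one VJP per output (2)
J = jax.jacobian(f)(x)      # alias of jacrev
```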

#jax.linearize

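evaluates f at a point x and returns the output together with f_lin, a linear function computing the JVP $\mathrm{J}_f(x)\,\mathbf{v}$ (the linear approximation of f around x)

useful when many JVPs are needed at the same point: the primal computation is done once and reused by every f_lin(v) call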

y, f_lin = jax.linearize(f, x)
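A minimal sketch, assuming the same illustrative f: R² → R². It linearizes once, then checks that `f_lin(v)` matches `jax.jvp` for each basis direction.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([2.0, 3.0])
y, f_lin = jax.linearize(f, x)   # primal evaluated once

matches = []
for v in (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])):
    _, jvp_ref = jax.jvp(f, (x,), (v,))
    matches.append(bool(jnp.allclose(f_lin(v), jvp_ref)))
print(matches)   # [True, True]
```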

#jax.hessian

computes the Hessian matrix (second derivatives) of a scalar-valued function

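import jax
import jax.numpy as jnp

def f(x):
  return x[0]**2 + x[0]*x[1] + x[1]**3

x = jnp.array([1., 2.])
v = jnp.array([1., 1.])
H = jax.jit(jax.hessian(f))(x)   # 2x2 matrix of second derivatives
Hv = H @ v                        # Hessian-vector product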

alternatives:

using jax.jacfwd and jax.jacrev

import jax
from jax import jacfwd, jacrev

def hessian(f):
    # forward-over-reverse: Jacobian of the gradient
    return jax.jit(jacfwd(jacrev(f)))

H = hessian(f)(x)   # f, x, v as defined above
Hv = H @ v

using jax.linearize and jax.grad

import jax
import jax.numpy as jnp

# linearize the gradient at x: f_lin(v) = H(x) v, without building H
grad_fx, f_lin = jax.linearize(
	jax.grad(f), x
)
v = jnp.array([1., 1.])
Hv = jax.jit(f_lin)(v)
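A further option, sketched here with a hypothetical helper `hvp`: take the JVP of `jax.grad(f)` directly (forward-over-reverse). This gives H v without materializing H or keeping a linearization around. For the f above at x = [1, 2], H = [[2, 1], [1, 12]], so H v = [3, 13].

```python
import jax
import jax.numpy as jnp

def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3

def hvp(f, x, v):
    # hypothetical helper: JVP of the gradient function gives H(x) v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])
Hv = hvp(f, x, v)   # [3., 13.]
```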

vmap & pmap

jax.vmap traces its function just like jax.jit and vectorizes it over a batch axis of the inputs (axis 0 by default); the mapped axis of each input is set with in_axes, and the position of the batch axis in the outputs with out_axes

(in_axes=None for an argument means it is not mapped: the same value is shared across the whole batch)
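A minimal sketch of `in_axes` and `out_axes`, assuming an illustrative per-vector function `dot`. Rows of `A` are mapped while `b` is shared; mapping axis 1 of `A.T` gives the same result; `out_axes=1` places the batch axis last in the output.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)   # written for single vectors

A = jnp.arange(6.0).reshape(3, 2)   # batch of 3 vectors: [[0,1],[2,3],[4,5]]
b = jnp.array([1.0, 1.0])            # shared, not mapped

batched = jax.vmap(dot, in_axes=(0, None))(A, b)    # [1., 5., 9.]
cols = jax.vmap(dot, in_axes=(1, None))(A.T, b)     # same, mapping over columns

# out_axes=1 stacks the per-row results along axis 1: shape (2, 3)
scaled = jax.vmap(lambda row: row * 2.0, out_axes=1)(A)
```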