Array

Attributes

Type promotion

unlike NumPy, jax.numpy doesn’t promote to float64 by default, for efficiency and compatibility (TPUs); it also supports bfloat16

use jax.config.update('jax_numpy_dtype_promotion', 'strict') to disable implicit type promotion (Python scalar promotion is still allowed for convenience)
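a minimal sketch of both behaviours (assuming the default config where x64 is disabled):

import jax
import jax.numpy as jnp

# standard promotion: int32 + float32 -> float32 (never float64)
jnp.array(1, dtype=jnp.int32) + jnp.array(2.0, dtype=jnp.float32)

jax.config.update('jax_numpy_dtype_promotion', 'strict')
jnp.array(1.0, dtype=jnp.float32) + 3.0   # still fine: Python scalars are weakly typed
# jnp.array(1, dtype=jnp.int32) + jnp.array(2.0, dtype=jnp.float32)  # would now raise a promotion error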

Transformations

Jit

wrap input args → record all operations → produce a jaxpr using the tracer records

| with jit | without jit |
| --- | --- |
| jaxpr operations are passed all at once as a graph / sequence | operations are passed one by one as they are interpreted |

traces a function during the first call, wraps objects & functions in Traced<…>, and keeps track of shapes and types (like graph guards in torch.compile) to produce operations that can be passed to XLA to compute

jax.make_jaxpr can be used to extract the produced sequence of operations
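a quick sketch of inspecting the jaxpr of a toy function (the printed form shown in the comment is approximate):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

print(jax.make_jaxpr(f)(jnp.arange(3.0)))
# roughly: { lambda ; a:f32[3]. let b:f32[3] = sin a; c:f32[3] = mul b 2.0 in (c,) }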

static args

static operations are evaluated at compile-time in python

traced operations are evaluated at run-time in XLA

arguments to be used in control flow should be flagged as static using static_argnums=(static_arg_pos, )

or static_argnames=("arg_names",)

import jax
import jax.numpy as jnp

# alternatively: @partial(jax.jit, static_argnums=(1,))
def f(x, y):
    return x * 2 if y else x   # y drives Python control flow -> must be static

x = jnp.arange(3)
double = True
jax.jit(f, static_argnums=(1,))(x, double)

compiled func caching

the jax.jit decorator caches a compiled version of the function, keyed by the function’s hash; as long as the traced inputs keep the same types and shapes the cached version is reused, otherwise the function is re-traced, recompiled and cached again

when an already-jitted function is wrapped again with jax.jit, the cached version of the inner function will be reused

avoid producing a new hash on each call (e.g. re-wrapping with jax.jit, or jitting lambdas / partials defined inside a loop), since that defeats the cache and forces recompilation

implicit, unwanted behaviors

notes

donated args

arguments passed to a jitted function can be “donated”, i.e. the device memory already allocated for the input values of that argument is reused to store the output / updated value (post func)
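a minimal sketch with donate_argnums (the `update` function and step size are made up for the example; on some backends, e.g. CPU, donation is ignored with a warning):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))   # donate params' device buffer
def update(params, grads):
    return params - 0.1 * grads

params = jnp.ones(1024)
grads = jnp.full(1024, 0.5)
params = update(params, grads)   # the old params buffer may be reused; don't read it afterwards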

Tracing shapes with no computation

using jax.eval_shape we can trace inputs to figure out output shapes without any computation (useful to avoid compiling twice when one of the arguments starts out as None and later becomes an array)
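a sketch (f, x_spec, w_spec are made-up names): shapes and dtypes are computed abstractly, with no FLOPs and no device allocation:

import jax
import jax.numpy as jnp

def f(x, w):
    return x @ w

x_spec = jax.ShapeDtypeStruct((2, 4), jnp.float32)
w_spec = jax.ShapeDtypeStruct((4, 8), jnp.float32)
out = jax.eval_shape(f, x_spec, w_spec)
out.shape, out.dtype   # (2, 8), float32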

Auto-diff

#jax.grad

jax.grad traces and transforms a given function f and returns a function that, given a set of inputs, differentiates f w.r.t. the first one

by default, a jax.grad-transformed function computes the gradient w.r.t. its first argument (jnp.array or PyTree) by tracing all operations inside the function

jax.value_and_grad is similar to grad, but evaluates both the function & its grad

y, dfdx = jax.value_and_grad(f)(x)
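a tiny concrete example:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)
y, dfdx = jax.value_and_grad(f)(x)   # y = 5.0, dfdx = [0., 2., 4.]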

#jax.jacobian

jax.jacobian is similar to jax.grad but operates on vector-valued functions

#jax.jvp

forward mode auto-diff to compute the Jacobian vector product

i.e: directional derivative along an input vector v

v is a tangent vector in input space

y, jvp = jax.jvp(f, (x,), (v,))

$$ y = f(x), \quad \mathrm{jvp}(\mathbf{v}) = \mathrm{J}_f(x)\,\mathbf{v} $$

$$ \mathrm{J}_f(x) = \begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix} $$

#jax.vjp

reverse-mode vector-Jacobian product: computes $\mathbf{c}^\top \mathrm{J}_f(x)$, i.e. pulls a cotangent vector back from output space to input space

c is a cotangent vector in output space

y, pullback = jax.vjp(f, x)

$$ y = f(x), \quad \mathrm{vjp}(\mathbf{c}) = \mathbf{c}^\top \mathrm{J}_f(x) $$
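a small numeric sketch contrasting the two (f here is a made-up R² → R² function):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])   # tangent in input space
c = jnp.array([1.0, 1.0])   # cotangent in output space

y, tangents_out = jax.jvp(f, (x,), (v,))   # J_f(x) @ v
y2, pullback = jax.vjp(f, x)
(cotangents_in,) = pullback(c)             # c^T @ J_f(x)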


#jax.jacfwd

computes the full Jacobian matrix using forward-mode diff: efficient when the input dim is small

#jax.jacrev

computes the full Jacobian matrix using reverse-mode diff: efficient when the output dim is small

#jax.jacobian

a high-level wrapper over the two; in practice it is an alias of jax.jacrev, so use jax.jacfwd explicitly when the input dim is the small one
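a quick sketch comparing the two on a made-up R³ → R² function (both return the same (2, 3) matrix):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.arange(1.0, 4.0)
J_fwd = jax.jacfwd(f)(x)   # one forward pass per input dim (3)
J_rev = jax.jacrev(f)(x)   # one backward pass per output dim (2)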

#jax.linearize

evaluates a function f at a point / input vector x and returns the output together with a linear approximation of f at that point

used for efficient compute when many JVPs are needed at the same point (the linearization work is done once and reused)

y, f_lin = jax.linearize(f, x)
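a sketch of reusing the linearization (f is made up): f_lin applies J_f(x) to tangents without re-tracing f:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.0)
y, f_lin = jax.linearize(f, x)
f_lin(jnp.ones_like(x))   # same value as jax.jvp(f, (x,), (jnp.ones_like(x),))[1]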

#jax.hessian

computes the Hessian matrix (forward-over-reverse: jacfwd(jacrev(f)))

import jax
import jax.numpy as jnp

def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3

x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])

H = jax.jit(jax.hessian(f))(x)   # full (2, 2) Hessian
H @ v                            # Hessian-vector product

alternatives:
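one common alternative (the standard forward-over-reverse trick, sketched here rather than taken from these notes) computes a Hessian-vector product without ever materializing H:

import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # differentiate grad(f) along direction v: forward-over-reverse
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3

x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])
hvp(f, x, v)   # matches jax.hessian(f)(x) @ v, without building the full matrix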

vmap & pmap

vmap traces input arguments just like jax.jit and maps over a batch dimension at axis 0 by default; the batch dimension position can be configured for inputs and outputs with in_axes & out_axes

(if set to None, the corresponding argument won’t be vectorized)

xs  = jnp.stack([x, x])   # (b, n): batch of x along axis 0
xst = xs.T                # (n, b): batch of x along axis 1

jax.vmap(convolve1d, in_axes=(1, None), out_axes=0)(xst, w)  # (b, n)

note: in_axes & out_axes in vmap can be PyTrees with the same structure as the corresponding input & output PyTrees

and a whole subtree can be covered by a single value (no need to spell out the full structure with the same value at each leaf)
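a self-contained version of the snippet above, assuming convolve1d is a per-example 1-D convolution (implemented here with jnp.convolve):

import jax
import jax.numpy as jnp

def convolve1d(x, w):
    return jnp.convolve(x, w, mode='same')   # single example: (n,) -> (n,)

x = jnp.arange(5.0)               # (n,)
w = jnp.array([1.0, 2.0, 1.0])    # (k,), shared across the batch (in_axes entry = None)
xst = jnp.stack([x, x + 1.0]).T   # (n, b): batch along axis 1

jax.vmap(convolve1d, in_axes=(1, None), out_axes=0)(xst, w)   # (b, n)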

vmap: vectorized map

pmap: parallel map (SPMD across devices)

Jaxpr

Jaxpr: JAX exPRession (the torch IR equivalent)

a representation of a sequence of primitive operations, encoding the shape & type of each operation’s args, inputs and outputs

PyTree

a functional freak solution for tossing states around

a tree is a nested collection of tensors such as a model’s params, an optimizer state, …

Key paths

a leaf’s key path is a list of keys whose length equals the depth of the leaf

jax.tree_util.keystr() provides a reader-friendly key-path repr, with preset key types for built-in nodes and __str__ for custom ones
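a small sketch (the tree contents are made up); the printed output is shown roughly:

import jax

tree = {"params": {"b": 2.0, "w": 1.0}, "step": 0}
flat, treedef = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in flat:
    print(jax.tree_util.keystr(path), leaf)
# ['params']['b'] 2.0
# ['params']['w'] 1.0
# ['step'] 0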


#jax.tree.structure & #jax.tree.leaves

skims a collection and extracts its structure / leaves

PyTreeDef([*, *, {"key": *, "keys": {"key1": *, "key2": *}}])
[num1, num2, Array([…], dtype=…), Array([…], dtype=…), Array([…], dtype=…)]
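a sketch that produces output of this form (jax.tree.structure / jax.tree.leaves exist in recent JAX; older versions expose them as jax.tree_util.tree_structure / tree_leaves):

import jax
import jax.numpy as jnp

tree = [1, 2, {"key": jnp.zeros(3), "keys": {"key1": jnp.ones(2), "key2": jnp.arange(4)}}]
jax.tree.structure(tree)   # PyTreeDef([*, *, {'key': *, 'keys': {'key1': *, 'key2': *}}])
jax.tree.leaves(tree)      # [1, 2, Array(...), Array(...), Array(...)]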