unlike numpy, jax.numpy doesn't promote to float64 by default, for efficiency and compatibility (TPUs); it also supports bfloat16
use jax.config.update('jax_numpy_dtype_promotion', 'strict') to disable implicit type promotion (python scalars are still promoted for convenience)
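a minimal sketch of strict promotion (the array values and the float16 mix-in are illustrative):
import jax
import jax.numpy as jnp

jax.config.update('jax_numpy_dtype_promotion', 'strict')

x = jnp.ones(3)                        # float32 by default (no float64 promotion)
y = x + 1.0                            # OK: python scalars are still promoted
# x + jnp.ones(3, dtype=jnp.float16)   # mixing strongly-typed dtypes raises under 'strict'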
wrap input args as tracers → record all operations → produce a jaxpr from the tracer records
with jit | without jit |
---|---|
jaxpr operations are passed to XLA all at once as a graph / sequence | operations are passed one by one as they are interpreted |
jit traces a function during the first call, wrapping inputs in Traced<…> objects and keeping track of shapes & types (like graph guards in torch.compile) to produce operations that can be passed to XLA to compute
jax.make_jaxpr can be used to extract the produced sequence of operations
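for example (f below is an illustrative function):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

print(jax.make_jaxpr(f)(jnp.arange(3.)))   # prints the recorded primitives (sin, mul) as a jaxpr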
static operations are evaluated at compile-time in python
traced operations are evaluated at run-time in XLA
arguments used in python control flow should be flagged as static using static_argnums=(static_arg_pos,) or static_argnames=("arg_name",)
import jax
import jax.numpy as jnp

# equivalently, as a decorator: @partial(jax.jit, static_argnums=(1,))
def f(x, y):
    return x * 2 if y else x

x = jnp.arange(3.)
double = True
jax.jit(f, static_argnums=(1,))(x, double)
jax.jit caches a compiled version of the function keyed on its hash and the traced inputs; as long as the traced inputs have the same types and shapes the cached executable is reused, otherwise the function is retraced, recompiled and cached
decorate the function with jax.jit (or create the jitted wrapper once) and reuse it, instead of re-wrapping on every call, to avoid producing a new hash / cache entry each call
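a sketch of the caching behaviour (shapes and the lambda are illustrative):
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * 2

f(jnp.ones(3))    # first call: trace + compile, executable cached
f(jnp.zeros(3))   # same shape & dtype: cache hit, no retracing
f(jnp.ones(4))    # new shape: retrace + recompile

g = jax.jit(lambda x: x * 2)   # create the jitted wrapper once and reuse it,
g(jnp.ones(3))                 # instead of jax.jit(lambda x: x * 2)(...) on every call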
#jax.grad
jax.grad traces and transforms a given function f and returns a function that, given a set of inputs, differentiates f w.r.t. the first one
by default a jax.grad-transformed function computes the gradient w.r.t. its first argument (a jnp.array or PyTree) by tracing all operations inside that function
jax.value_and_grad is similar to grad, but evaluates the function & its gradient
y, dfdx = jax.value_and_grad(f)(x)
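a small runnable sketch (f and x are illustrative):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1., 2., 3.])
dfdx = jax.grad(f)(x)               # gradient w.r.t. the first argument: 2 * x
y, dfdx = jax.value_and_grad(f)(x)  # function value and gradient in one call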
#jax.jacobian
jax.jacobian is similar to jax.grad but operates on vector-valued functions
#jax.jvp
forward-mode auto-diff to compute the Jacobian-vector product, i.e. the directional derivative along an input vector v (see the sketch after the Jacobian definition below)
v is a tangent vector in input space
y, jvp = jax.jvp(f, (x,), (v,))
$$ y = f(x), \quad \mathrm{jvp}(\mathbf{v}) = \mathrm{J}_f(x).\mathbf{v} $$
$$ \mathrm{J}_f(x) = \begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix} $$
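a minimal jvp sketch (f, x and v are illustrative):
import jax
import jax.numpy as jnp

def f(x):                                # R^2 -> R^2
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([1., 2.])
v = jnp.array([1., 0.])                  # tangent vector in input space
y, jvp_out = jax.jvp(f, (x,), (v,))      # y = f(x), jvp_out = J_f(x) @ v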
#jax.vjp
reverse-mode vector-Jacobian product: computes $\mathbf{c}^\top . \mathrm{J}_f(x)$, i.e. pulls a cotangent vector $\mathbf{c}$ from output space back to input space
c is a cotangent vector in output space
y, pullback = jax.vjp(f, x)
$$ y = f(x), \quad \mathrm{vjp}(\mathbf{c}) = \mathbf{c}^\top . \mathrm{J}_f(x) $$
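a matching vjp sketch (same illustrative f and x as in the jvp example):
import jax
import jax.numpy as jnp

def f(x):                                # R^2 -> R^2
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([1., 2.])
y, pullback = jax.vjp(f, x)              # y = f(x); pullback maps output cotangents to input space
c = jnp.array([1., 0.])                  # cotangent vector in output space
(vjp_out,) = pullback(c)                 # vjp_out = c^T @ J_f(x)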
#jax.jacobian
#jax.jacfwd
computes the full Jacobian matrix using forward-mode diff: efficient when the input dimension is small
#jax.jacrev
computes the full Jacobian matrix using reverse-mode diff: efficient when the output dimension is small
#jax.jacobian
a high-level wrapper for computing full Jacobians; in current JAX it is an alias of jax.jacrev rather than dynamically choosing between jax.jacrev and jax.jacfwd, so use jax.jacfwd explicitly when the input dimension is small
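a sketch comparing the three on a small vector-valued function (f, x and the shapes are illustrative):
import jax
import jax.numpy as jnp

def f(x):                         # R^2 -> R^3
    return jnp.array([x[0] ** 2, x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1., 2.])
J_fwd = jax.jacfwd(f)(x)          # shape (3, 2), forward mode: one pass per input dim
J_rev = jax.jacrev(f)(x)          # shape (3, 2), reverse mode: one pass per output dim
J = jax.jacobian(f)(x)            # same matrix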
#jax.linearize
evaluates a function f at a point x and returns the output together with a linear approximation of f around that point (a function computing JVPs at x); efficient when many JVPs at the same point are needed, since the linearization is computed once
y, f_lin = jax.linearize(f, x)
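for example (f and the tangent vectors are illustrative); the linearization is computed once and reused:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.)
y, f_lin = jax.linearize(f, x)    # y = f(x); f_lin(v) computes J_f(x) @ v
jvp1 = f_lin(jnp.ones(3))         # reuse the same linearization
jvp2 = f_lin(jnp.arange(3.))      # for many tangent vectors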
#jax.hessian
computes the Hessian matrix
import jax
import jax.numpy as jnp

def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3

x = jnp.array([1., 2.])
v = jnp.array([1., 1.])
H = jax.jit(jax.hessian(f))(x)   # full Hessian at x
Hv = H @ v                       # Hessian-vector product
alternatives:
using jax.jacfwd and jax.jacrev
import jax
from jax import jacfwd, jacrev

def hessian(f):
    return jax.jit(jacfwd(jacrev(f)))

H = hessian(f)(x)
Hv = H @ v
using jax.linearize and jax.grad
import jax
import jax.numpy as jnp

g, f_lin = jax.linearize(jax.grad(f), x)  # g = grad of f at x; f_lin computes HVPs
v = jnp.array([1., 1.])
Hv = jax.jit(f_lin)(v)                    # Hessian-vector product without forming H
vmap traces input arguments just like jax.jit and vectorizes the function over a batch dimension, at axis 0 by default; the batch dimension position can be configured for inputs and outputs with in_axes & out_axes (if set to None, the corresponding argument won't be vectorized)
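a small sketch of in_axes / out_axes (the affine function, batch size and shapes are illustrative):
import jax
import jax.numpy as jnp

def affine(x, w, b):
    return w @ x + b

xs = jnp.ones((8, 4))             # batch of 8 input vectors
w = jnp.ones((3, 4))              # shared weights
b = jnp.ones(3)

# map over axis 0 of xs only; w and b are not vectorized (in_axes=None)
ys = jax.vmap(affine, in_axes=(0, None, None), out_axes=0)(xs, w, b)
print(ys.shape)                   # (8, 3)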