JAX Notes

JAX

JAX is an ML framework and an alternative to PyTorch and TensorFlow. Doc

Key Features

  • JAX brings Autograd (jax.grad) and XLA (jax.jit) together behind a NumPy interface (jax.numpy).

  • JAX supports eager/interpreted mode and graph/compiled mode with jax.jit.

  • JAX supports automatic vectorization with jax.vmap and multi-device (SPMD) parallelism with jax.pmap.

  • JAX supports distributed arrays and automatic data/tensor parallelism with jax.sharding.

Key APIs

import jax
import jax.numpy as jnp

# 1. `jax.jit`
# Tracing-based graph capture. Enables op fusion and optimizes communication.
# - Frontend: NumPy interface (`jax.numpy`) ->
#             lax primitives (`jax.lax`) ->
#             jax expression / jaxpr (`jax.core.Jaxpr`) ->
# - Serialization: StableHLO module ->
# - Middle-end: HLO module ->
# - Backend: CPU (LLVM/x86) / GPU (LLVM/PTX) / TPU (XLA/LLO)
# Still use the plain NumPy interface for static (compile-time constant) tensors,
# so they are computed at trace time rather than traced into the graph.

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
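
# Illustration (not in the original notes): inspecting the stages described above.
x = jnp.arange(3.0)
print(jax.make_jaxpr(selu)(x))   # captured jax expression (jaxpr)
print(selu.lower(x).as_text())   # StableHLO module produced by lowering
print(selu(x))                   # compiled by XLA on first call, cached afterwards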

# 2. `jax.grad`
# Chain-rule-based gradient calculation.
# Computes gradients of a scalar-output function with respect to its inputs.
# - Forward: a jax expression built from lax primitives.
# - Backward: a jax expression built from lax primitives.
#             Each lax primitive defines its gradient in terms of lax primitives.
#             The backward jax expression is assembled from these gradients via the chain rule.
# - `jax.grad` and `jax.jit` can be used together or separately, since lax primitives run in both eager and graph mode.
# Related:
# - `jax.jacfwd`: Computes the Jacobian matrix. Forward mode. Column by column.
# - `jax.jacrev`: Computes the Jacobian matrix. Reverse mode. Row by row.

@jax.grad
@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
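
# Illustration (not in the original notes): `selu` is now grad(jit(selu)); the Jacobians use a hypothetical helper `f`.
print(selu(1.0))                        # d(selu)/dx at x = 1.0 (scalar in, scalar out)
f = lambda x: jnp.sin(x) * x            # vector -> vector
print(jax.jacfwd(f)(jnp.arange(3.0)))   # forward-mode Jacobian, built column by column
print(jax.jacrev(f)(jnp.arange(3.0)))   # reverse-mode Jacobian, built row by row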

# 3. `jax.vmap`
# Vectorizing map: adds batch dimensions to a function's inputs/outputs.
# Example: Using vector dot product to implement matrix multiplication.
vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
mv = jax.vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = jax.vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
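
# Illustration (not in the original notes): checking the mapped shapes.
A = jnp.ones((4, 3))    # [b, a]
x = jnp.ones(3)         # [a]
B = jnp.ones((3, 5))    # [a, c]
print(mv(A, x).shape)   # (4,)
print(mm(A, B).shape)   # (4, 5)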

# 4. `jax.pmap`
# Parallel map over devices (SPMD): the mapped axis is split across devices, one slice per device.
# Communication between devices is delayed until evaluation.
# Example: Using vector dot product to implement matrix multiplication,
#          with the c dimension spread across devices.
vv = lambda x, y: jnp.vdot(x, y)                    #  ([a], [a]) -> []
mv = jax.vmap(vv, (0, None), 0)                     #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = jax.pmap(mv, in_axes=(None, 1), out_axes=1)    #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
# Related:
# - `jax.lax.p*` (`jax.lax.psum`, etc.): collective ops that run across the devices of the mapped axis.
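
# Illustration (not in the original notes): a collective inside pmap.
# Each device holds one element; psum yields the global sum on every device.
n = jax.local_device_count()
out = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(jnp.arange(1.0, n + 1))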

# 5. `jax.distributed`
# - Launch one process per host and call `jax.distributed.initialize()`.
# - Each process owns a set of local devices; combined, they form the set of global devices.
# - Use the standard parallelism APIs to communicate across processes.
# - Make sure all processes run the same parallel computations in the same order.
# The following runs in parallel on each host of a GPU cluster or TPU pod slice.
>>> jax.distributed.initialize()  # On GPU, pass coordinator/process arguments unless they can be auto-detected.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

# 6. `jax.sharding`
# Automatic tensor parallelism with user-specified constraints.
# Use `jax.lax.with_sharding_constraint` to constrain the sharding of intermediates (see the sketch after this example).
# Example: 4-way batch data parallelism and 2-way model tensor parallelism
sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(4, 2)
# Shard inputs 4-way along the batch dimension; replicate x2 across the model axis.
batch = jax.device_put(batch, sharding.replicate(1))
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
# No sharding on 1st layer parameters. Duplicate x8.
W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())
# Sharding along axis=1 on 2nd layer parameters. Duplicate x4.
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))
# Sharding along axis=0 on 3rd layer parameters. Duplicate x4.
W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())
# No sharding on 4th layer parameters. Duplicate x8.
W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
...
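
# Sketch (not in the original notes): constraining an intermediate inside `jax.jit`,
# reusing the `sharding`, `params`, and `batch` defined above.
@jax.jit
def first_layer(params, x):
  (W1, b1), _, _, _ = params
  y = jnp.maximum(x @ W1 + b1, 0)
  # Keep the activation sharded like the batch: 4-way on the batch axis, replicated x2.
  return jax.lax.with_sharding_constraint(y, sharding.replicate(1))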

FLAX

Flax is an ML library built on top of JAX. Alternatives include Pax, Praxis, and Haiku on JAX, or Keras on TensorFlow. Doc

Key Features

  • Modular ML components and a standard API, flax.linen.

Key APIs

import flax.linen as nn

# `nn.Module` is a `dataclasses.dataclass` subclass; hyperparameters are declared as class attributes.
class AnyModule(nn.Module):
  features: int = 1
  ...

# 1. There are two ways to define the submodules inside a module.
# 1.1 Using `setup` to define them explicitly as attributes.
class AnyModule(nn.Module):
  def setup(self):
    self.dense = nn.Dense(self.features)

  def __call__(self, *args):
    ...

# 1.2 OR, using `__call__` with `@nn.compact` to define submodules inline.
class AnyModule(nn.Module):
  @nn.compact
  def __call__(self, *args):
    dense = nn.Dense(self.features)
    ...

module = AnyModule()

# 2. Initialize all variables.
# The module itself stays stateless; the variables are returned separately.
# Variables are nested dictionaries, keyed by collection and name.
rng = jax.random.key(0)
variables = module.init(rng, *args)

# 3. Call an unbound module with variables.
# The module stays stateless.
# Supply the variables in the expected tree structure.
output = module.apply(variables, *args)

# 4. Bind variables to the module.
# The resulting bound module is stateful.
bound_module = module.bind(variables)

# 5. Call a bound module.
output = bound_module(*args)
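
# Sketch (not in the original notes): the full lifecycle with a hypothetical one-layer module.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: int = 4

  @nn.compact
  def __call__(self, x):
    return nn.relu(nn.Dense(self.features)(x))

module = MLP(features=4)
x = jnp.ones((2, 3))
variables = module.init(jax.random.key(0), x)   # {'params': {'Dense_0': {'kernel': ..., 'bias': ...}}}
y = module.apply(variables, x)                  # stateless call
bound_module = module.bind(variables)           # stateful (bound) copy
y2 = bound_module(x)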

Orbax

Orbax is a checkpointing library for JAX. Doc
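
A minimal sketch (not from the original notes), assuming the PyTree checkpointer API and a hypothetical /tmp path:

import orbax.checkpoint as ocp

ckptr = ocp.PyTreeCheckpointer()
ckptr.save('/tmp/ckpt', variables)      # save a pytree of arrays (e.g. Flax variables)
restored = ckptr.restore('/tmp/ckpt')   # restore the same pytree structure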

Optax

Optax is a gradient processing and optimization library for JAX. Doc
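
A minimal sketch (not from the original notes), assuming hypothetical `params` and `grads` pytrees:

import optax

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)                        # optimizer state for the params pytree
updates, opt_state = optimizer.update(grads, opt_state)   # transform raw gradients into updates
params = optax.apply_updates(params, updates)             # apply updates to the parameters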
