Learn Optax

Quick Start

Let’s use optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.

We will begin by creating a dataset that consists of batches of random 8-bit integers (each encoded as the vector of its 8 bits), with each value labelled as “odd” or “even” using one-hot encoding. Since each label is the one-hot encoding of the value modulo 2, [1, 0] means even and [0, 1] means odd.

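```python
import jax
import jax.numpy as jnp
import numpy as np
import optax

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000

# One batch of random integers in [0, 255) per training step.
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

# Unpack each value into its 8 bits: shape (NUM_TRAIN_STEPS, BATCH_SIZE, 8).
TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
# One-hot encode value % 2, so [1, 0] means even and [0, 1] means odd.
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)
```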
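To see the encoding on concrete values, the short sketch below unpacks two sample integers (6 and 7, chosen arbitrarily) and labels them the same way as the dataset above. `np.unpackbits` emits the most significant bit first, so the parity of a value ends up in the last of its 8 bits, which is exactly the feature the network has to discover.

```python
import jax
import numpy as np

# Two arbitrary sample values: 6 is even, 7 is odd.
values = np.array([[6], [7]], dtype=np.uint8)

# Same encoding as the training data: 8 bits per value, most significant bit first.
bits = np.unpackbits(values, axis=-1)
print(bits)
# [[0 0 0 0 0 1 1 0]
#  [0 0 0 0 0 1 1 1]]

# Same labelling as the training data: one-hot of value % 2.
# Expected: 6 (even) -> [1, 0], 7 (odd) -> [0, 1].
labels = jax.nn.one_hot(values[:, 0] % 2, 2)
print(labels)
```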

We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.

There are a number of libraries that provide common building blocks for parametrized functions (such as flax and haiku). For this case though, we shall implement our function from scratch.

Our function will be a small MLP (multi-layer perceptron) with one hidden layer of 32 ReLU units followed by a linear output layer with 2 units; for simplicity, neither layer has a bias term. We initialize all parameters using a standard Gaussian \(\mathcal{N}(0,1)\) distribution.

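```python
# Each layer gets its own PRNG key so the two weight matrices are independent.
initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}

def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  # Hidden layer: (batch, 8) @ (8, 32) -> (batch, 32), followed by ReLU.
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  # Output layer: (batch, 32) @ (32, 2) -> (batch, 2) logits.
  x = jnp.dot(x, params['output'])
  return x
```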
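As a quick sanity check on the shapes, the sketch below re-creates the same network and runs it on a hypothetical batch of 5 random bit vectors (the batch size and the use of `PRNGKey(2)` for the inputs are arbitrary choices for this illustration). Each 8-bit input row becomes 32 non-negative hidden activations and finally 2 output logits, one per class.

```python
import jax
import jax.numpy as jnp

def net(x, params):
  # Same 8 -> 32 -> 2 architecture as above (no biases).
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  return jnp.dot(x, params['output'])

params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}

# Hypothetical batch of 5 inputs, each an 8-element vector of 0/1 bits.
batch = jax.random.bernoulli(jax.random.PRNGKey(2), shape=(5, 8)).astype(jnp.float32)

hidden = jax.nn.relu(jnp.dot(batch, params['hidden']))
logits = net(batch, params)
print(hidden.shape, logits.shape)  # (5, 32) (5, 2)
```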

```python
def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  # Sum the per-output losses for each example...
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  # ...then average over the batch.
  return loss_value.mean()
```

We will use optax.adam to compute the parameter updates from their gradients on each optimizer step.

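Note that since optax optimizers are implemented using pure functions, we also need to keep track of the optimizer state ourselves and pass it into every update. For the Adam optimizer, this state holds running estimates of the first and second moments of the gradients (the momentum terms), together with a step count.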

```python
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  # jax.jit compiles the step on its first call; later calls with the same
  # input shapes reuse the compiled version.
  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
```
step 0, loss: 3.610544204711914
step 100, loss: 0.4551412761211395
step 200, loss: 0.01912495493888855
step 300, loss: 0.03166782110929489
step 400, loss: 0.013636551797389984
step 500, loss: 0.0016018962487578392
step 600, loss: 0.0031061682384461164
step 700, loss: 0.002592239063233137
step 800, loss: 0.0014123425353318453
step 900, loss: 0.01507741678506136

We see that the loss drops from about 3.6 at step 0 to a few hundredths or less from step 200 onwards, which indicates that training has found much better parameters for our network.

Weight Decay, Schedules and Clipping

Many research models make use of techniques such as weight decay, learning rate scheduling and gradient clipping. These may be achieved by chaining together gradient transformations such as optax.adam and optax.clip with optax.chain, and by passing a schedule in place of a fixed learning rate.

In the following, we will use Adam with weight decay (optax.adamw), a cosine learning rate schedule with linear warmup (optax.warmup_cosine_decay_schedule), and element-wise gradient clipping (optax.clip).

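```python
# Illustrative hyperparameters (assumed values).
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1.0, warmup_steps=50, decay_steps=1_000, end_value=0.0)

optimizer = optax.chain(
    optax.clip(1.0),  # clip each gradient element to [-1, 1]
    optax.adamw(learning_rate=schedule),  # Adam with decoupled weight decay
)
params = fit(initial_params, optimizer)
```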
step 0, loss: 3.610544204711914
step 100, loss: 0.0
step 200, loss: 0.0
step 300, loss: 0.0
step 400, loss: 0.0
step 500, loss: 0.0
step 600, loss: 0.0
step 700, loss: 0.0
step 800, loss: 0.0
step 900, loss: 0.0