Learn Optax

Quick Start

Let’s use optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.

We will begin by creating a dataset that consists of batches of random 8-bit integers (represented using their binary representation), with each value labelled as “odd” or “even” using one-hot encoding. Because the label index is the value modulo 2, [1, 0] means even and [0, 1] means odd.

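import optax
import jax.numpy as jnp
import jax
import numpy as np

# BATCH_SIZE and NUM_TRAIN_STEPS are not defined in this excerpt, so the values
# here are assumptions. 1_000 steps is consistent with the training log further
# down, which reports steps 0 through 900.
BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000

# The upper bound of np.random.randint is exclusive, so values lie in [0, 254].
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))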

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)
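
To make the encoding concrete, the following sketch applies the same two calls to a pair of hand-picked values. The values 5 and 6 and the variable names are chosen here purely for illustration and are not part of the dataset above.

```python
import jax
import numpy as np

# Two hand-picked example values (an assumption for illustration):
# 5 is odd, 6 is even.
values = np.array([[5], [6]], dtype=np.uint8)

# Each value becomes its 8 bits, most significant bit first.
bits = np.unpackbits(values, axis=-1)

# The class index is value % 2, so index 0 is "even" and index 1 is "odd".
labels = jax.nn.one_hot((values[:, 0] % 2).astype(np.int32), 2)

print(bits)
print(labels)
```

The parity of a value is carried entirely by the last bit of its binary representation, which is why even a small network can learn this task quickly.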

We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.

There are a number of libraries that provide common building blocks for parametrized functions (such as flax and haiku). For this case though, we shall implement our function from scratch.

Our function will be an MLP (multi-layer perceptron) with a single hidden layer and a single output layer, and without bias terms. We initialize all parameters using a standard Gaussian \(\mathcal{N}(0,1)\) distribution.

initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}

def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x

def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

We will use optax.adam to compute the parameter updates from their gradients on each optimizer step.

Note that since optax optimizers are implemented using pure functions, we will also need to keep track of the optimizer state. For the Adam optimizer, this state contains running estimates of the first and second moments of the gradients, along with a step count.

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
step 0, loss: 3.4678592681884766
step 100, loss: 0.6607946753501892
step 200, loss: 0.14849971234798431
step 300, loss: 0.027004409581422806
step 400, loss: 0.01529538631439209
step 500, loss: 0.02243456058204174
step 600, loss: 0.005322305951267481
step 700, loss: 0.002057740231975913
step 800, loss: 0.013683095574378967
step 900, loss: 0.0009163034264929593

We see that our loss appears to have converged, which suggests that we have found better parameters for our network.

Weight Decay, Schedules and Clipping

Many research models make use of techniques such as learning rate scheduling and gradient clipping. These may be achieved by chaining together gradient transformations such as optax.adam and optax.clip.

In the following, we will use Adam with weight decay (optax.adamw), a cosine learning rate schedule (with warmup) and also gradient clipping.

schedule = optax.warmup_cosine_decay_schedule(

optimizer = optax.chain(

params = fit(initial_params, optimizer)
step 0, loss: 3.4678592681884766
step 100, loss: 1.348135492129643e-10
step 200, loss: 6.232401192656954e-16
step 300, loss: 1.6034066435688388e-11
step 400, loss: 1.2280205759218901e-19
step 500, loss: 1.287136955178958e-12
step 600, loss: 1.1874783925266974e-13
step 700, loss: 1.9421543908694494e-14
step 800, loss: 1.280755917987264e-10
step 900, loss: 2.964963615337962e-14