Learn Optax

Quick Start

Let’s use optax to fit a parametrized function. We will consider the problem of learning to identify whether a value is odd or even.

We will begin by creating a dataset that consists of batches of random 8-bit integers (represented using their binary representation), with each value labelled as “odd” or “even” using 1-hot encoding (i.e. [1, 0] means even and [0, 1] means odd).

import random
from typing import Tuple

import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
RAW_TRAINING_DATA = np.random.randint(256, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

# Unpack each integer into its 8-bit binary representation.
TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
# 1-hot parity labels: index 0 = even, index 1 = odd.
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)
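
For example, inspecting a single training example shows the raw integer, its 8 bits, and its 1-hot parity label:

print(RAW_TRAINING_DATA[0, 0])  # the raw integer (a length-1 array)
print(TRAINING_DATA[0, 0])      # its 8-bit binary representation
print(LABELS[0, 0])             # its 1-hot parity label ([1, 0] = even, [0, 1] = odd)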

We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.

There are a number of libraries that provide common building blocks for parametrized functions (such as flax and haiku). For this example, though, we will implement our function from scratch.

Our function will be an MLP (multi-layer perceptron) with a single hidden layer and a single output layer. We initialize all parameters using a standard Gaussian \(\mathcal{N}(0,1)\) distribution.

initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()
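
As a quick check, we can evaluate the loss of the untrained network on the first batch; this value should decrease as training progresses.

print(loss(initial_params, TRAINING_DATA[0], LABELS[0]))  # loss with randomly initialized parameters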

We will use optax.adam to compute the parameter updates from their gradients on each optimizer step.

Note that since optax optimizers are implemented using pure functions, we will also need to keep track of the optimizer state. For the Adam optimizer, this state contains the first- and second-moment estimates of the gradients.
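
For example, initializing Adam's state for our parameters produces a pytree holding these moment estimates, one per parameter (the exact structure may vary between optax versions):

opt_state = optax.adam(learning_rate=1e-2).init(initial_params)
print(jax.tree_util.tree_map(jnp.shape, opt_state))  # shapes of the first- and second-moment estimates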

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
step 0, loss: 7.382710933685303
step 100, loss: 0.5551559925079346
step 200, loss: 0.028613239526748657
step 300, loss: 0.027356812730431557
step 400, loss: 0.03194188326597214
step 500, loss: 0.015071901492774487
step 600, loss: 0.005017134360969067
step 700, loss: 0.002122548408806324
step 800, loss: 0.04145048186182976
step 900, loss: 0.0007623251876793802

We see that the loss appears to have converged, which indicates that we have successfully found better parameters for our network.
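
As an additional sanity check, we can classify a handful of fresh 8-bit integers with the trained parameters (the test values below are chosen here purely for illustration):

test_values = np.array([[3], [8], [42], [137], [255]], dtype=np.uint8)
test_bits = np.unpackbits(test_values, axis=-1)            # same 8-bit encoding as the training data
predictions = jnp.argmax(net(test_bits, params), axis=-1)  # index 0 = even, index 1 = odd
print(jnp.mean(predictions == (test_values[:, 0] % 2)))    # fraction classified correctly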

Weight Decay, Schedules and Clipping

Many research models make use of techniques such as weight decay, learning rate scheduling, and gradient clipping. These may be achieved by chaining together gradient transformations such as optax.adam and optax.clip.

In the following, we will use Adam with weight decay (optax.adamw), a cosine learning rate schedule with warmup, and gradient clipping.

schedule = optax.warmup_cosine_decay_schedule(
  init_value=0.0,
  peak_value=1.0,
  warmup_steps=50,
  decay_steps=1_000,
  end_value=0.0,
)

optimizer = optax.chain(
  optax.clip(1.0),
  optax.adamw(learning_rate=schedule),
)

params = fit(initial_params, optimizer)
step 0, loss: 7.382710933685303
step 100, loss: 0.0
step 200, loss: 4.671875386903726e-20
step 300, loss: 9.647353172302246
step 400, loss: 0.0
step 500, loss: 0.0
step 600, loss: 0.0
step 700, loss: 0.0
step 800, loss: 0.0
step 900, loss: 0.0
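
Since an optax schedule is simply a function from the step count to a learning rate, we can also inspect the warmup and decay of the schedule defined above directly:

for step_count in (0, 25, 50, 500, 999):
  print(step_count, schedule(step_count))  # ramps up to peak_value by step 50, then decays towards end_value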