🚀 Getting started#

Optax is a simple optimization library for JAX. The main object is the GradientTransformation, which can be chained with other transformations to obtain the final update operation and the optimizer state. Optax also contains some simple loss functions and utilities to help you write the full optimization steps. This notebook walks you through a few examples of how to use Optax.

Example: Fitting a Linear Model#

Begin by importing the necessary packages:

import jax.numpy as jnp
import jax
import optax
import functools

In this example, we begin by setting up a simple linear model and a loss function. You can use any other library, such as Haiku or Flax, to construct your networks. Here, we keep it simple and write the model ourselves. The loss function (L2 loss) comes from Optax's losses via l2_loss.

@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  return jnp.dot(params, x)

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = jnp.mean(optax.l2_loss(y_pred, y))
  return loss

Here we generate data under a known linear model (with target_params=0.5):

key = jax.random.PRNGKey(42)
target_params = 0.5

# Generate some data.
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

Basic usage of Optax#

Optax contains implementations of many popular optimizers that can be used very simply. For example, the gradient transform for the Adam optimizer is available at optax.adam. For now, let's start by creating the GradientTransformation object for the Adam optimizer. We then initialize the optimizer state using the init function and the params of the network.

start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)

# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)

Next we write the update loop. The GradientTransformation object contains an update function that takes in the current optimizer state and gradients and returns the updates that need to be applied to the parameters: updates, new_opt_state = optimizer.update(grads, opt_state).

Optax comes with a few simple update rules that apply the updates from the gradient transforms to the current parameters to return new ones: new_params = optax.apply_updates(params, updates).

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'

Custom optimizers#

Optax makes it easy to create custom optimizers by chaining gradient transforms. For example, the following code creates an optimizer based on Adam. Note that the updates end up scaled by the negative learning rate (here via scale_by_schedule followed by scale(-1.0)), which is an important detail since apply_updates is additive.

# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip the gradient by the global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)
# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.
opt_state = gradient_transform.init(params)

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = gradient_transform.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'

Advanced usage of Optax#

Modifying hyperparameters of optimizers in a schedule.#

In some scenarios, changing the hyperparameters (other than the learning rate) of an optimizer can be useful to ensure training reliability. We can do this easily by using inject_hyperparams. For example, this piece of code decays the max_norm of the clip_by_global_norm gradient transform as training progresses:

decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(
    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))

opt_state = decaying_global_norm_tx.init(None)
assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'

for _ in range(100):
  _, opt_state = decaying_global_norm_tx.update(None, opt_state)

assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'

Example: Fitting a MLP#

Let’s use Optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.

We will begin by creating a dataset that consists of batches of random 8-bit integers (represented using their binary representation), with each value labelled as "even" or "odd" using 1-hot encoding (i.e. [1, 0] means even and [0, 1] means odd).

import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
RAW_TRAINING_DATA = np.random.randint(256, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))  # Upper bound is exclusive, so values cover 0..255.

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)
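It can help to check the encoding on two hand-picked values before training. The sketch below applies the same np.unpackbits and jax.nn.one_hot calls to the values 5 and 4, which are chosen only for illustration, so that the bit pattern and the label column used for odd and even values can be read off directly.

```python
import jax
import numpy as np

raw = np.array([[5], [4]], dtype=np.uint8)  # one odd and one even value

# Each value becomes a row of 8 bits, most significant bit first.
bits = np.unpackbits(raw, axis=-1)

# The class index is value % 2, so index 0 marks even and index 1 marks odd.
labels = jax.nn.one_hot((raw[:, 0] % 2).astype(np.int32), 2)

print(bits)    # rows 00000101 (5) and 00000100 (4)
print(labels)  # 5 maps to [0, 1], 4 maps to [1, 0]
```

The last bit of each row already equals the parity, so the network only has to learn to read that single input feature.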

We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.

There are a number of libraries that provide common building blocks for parametrized functions (such as Flax and Haiku). For this case though, we shall implement our function from scratch.

Our function will be a 1-layer MLP (multi-layer perceptron) with a single hidden layer, and a single output layer. We initialize all parameters using a standard Gaussian \(\mathcal{N}(0,1)\) distribution.

initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

We will use optax.adam to compute the parameter updates from their gradients on each optimizer step.

Note that since Optax optimizers are implemented using pure functions, we will also need to keep track of the optimizer state. For the Adam optimizer, this state contains the running estimates of the first and second moments of the gradients.

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
step 0, loss: 14.34672737121582
step 100, loss: 0.3440847098827362
step 200, loss: 0.024329861626029015
step 300, loss: 0.06388916075229645
step 400, loss: 0.009651201777160168
step 500, loss: 0.009742096066474915
step 600, loss: 0.006260392721742392
step 700, loss: 0.0019267834722995758
step 800, loss: 0.0008281346526928246
step 900, loss: 0.00033883278956636786

We see that our loss appears to have converged, which should indicate that we have successfully found better parameters for our network.

Weight Decay, Schedules and Clipping#

Many research models make use of techniques such as learning rate scheduling and gradient clipping. These can be achieved by chaining together gradient transformations such as optax.adam and optax.clip.

In the following, we will use Adam with weight decay (optax.adamw), a cosine learning rate schedule (with warmup) and also gradient clipping.

schedule = optax.warmup_cosine_decay_schedule(
  init_value=0.0,
  peak_value=1.0,
  warmup_steps=50,
  decay_steps=1_000,
  end_value=0.0,
)

optimizer = optax.chain(
  optax.clip(1.0),
  optax.adamw(learning_rate=schedule),
)

params = fit(initial_params, optimizer)
step 0, loss: 14.34672737121582
step 100, loss: 2.368289418741565e-11
step 200, loss: 3.444157112286739e-11
step 300, loss: 1.9778159665584383e-10
step 400, loss: 1.2797478354809044e-11
step 500, loss: 2.3227363088462738e-10
step 600, loss: 2.2316104519859437e-10
step 700, loss: 8.796226858009959e-09
step 800, loss: 2.651706436894441e-13
step 900, loss: 6.388567626303132e-12

Components#

We refer to the docs for a detailed list of available Optax components. Here, we highlight the main categories of building blocks provided by Optax.

Gradient Transformations (transform.py)#

One of the key building blocks of Optax is a GradientTransformation. Each transformation is defined by two functions:

state = init(params)

grads, state = update(grads, state, params=None)

The init function initializes a (possibly empty) set of statistics (aka state) and the update function transforms a candidate gradient given some statistics, and (optionally) the current value of the parameters.

For example:

tx = optax.scale_by_rms()
state = tx.init(params)  # init stats
grads = jax.grad(loss)(params, TRAINING_DATA, LABELS)
updates, state = tx.update(grads, state, params)  # transform & update stats.

Composing Gradient Transformations (combine.py)#

Transformations take candidate gradients as input and return processed gradients as output, rather than returning updated parameters. This design is what makes it possible to combine arbitrary transformations into a custom optimiser / gradient processor, and to combine transformations for different gradients that operate on a shared set of variables.

For instance, chain returns a new GradientTransformation that applies several transformations in sequence.

For example:

max_norm = 100.
learning_rate = 1e-3

my_optimiser = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale(-learning_rate))

Wrapping Gradient Transformations (wrappers.py)#

Optax also provides several wrappers that take a GradientTransformation as input and return a new GradientTransformation that modifies the behaviour of the inner transformation in a specific way.

For instance, the flatten wrapper flattens gradients into a single large vector before applying the inner GradientTransformation. The transformed updates are then unflattened before being returned to the user. This can be used to reduce the overhead of performing many calculations on lots of small variables, at the cost of increasing memory usage.

For example:

my_optimiser = optax.flatten(optax.adam(learning_rate))

Other examples of wrappers include accumulating gradients over multiple steps or applying the inner transformation only to specific parameters or at specific steps.

Schedules (schedule.py)#

Many popular transformations use time-dependent components, e.g. to anneal a hyper-parameter such as the learning rate. For this purpose, Optax provides schedules that compute a scalar as a function of the step count.

For example, you may use a polynomial_schedule (with power=1) to decay a hyper-parameter linearly over a number of steps:

schedule_fn = optax.polynomial_schedule(
    init_value=1., end_value=0., power=1, transition_steps=5)

for step_count in range(6):
  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]
1.0
0.8
0.6
0.39999998
0.19999999
0.0
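For intuition, the values above can be reproduced by hand. The sketch below is a plain-Python re-implementation under the assumption that the schedule computes (init_value - end_value) * (1 - step / transition_steps) ** power + end_value, with the step clipped to the transition range. The function name polynomial_decay is made up for this illustration.

```python
def polynomial_decay(step, init_value=1.0, end_value=0.0, power=1,
                     transition_steps=5):
  """Hand-written polynomial decay, assumed to mirror the schedule above."""
  step = min(max(step, 0), transition_steps)  # clip to the transition range
  frac = 1.0 - step / transition_steps
  return (init_value - end_value) * frac ** power + end_value


values = [polynomial_decay(step) for step in range(6)]
print(values)  # approximately 1.0, 0.8, 0.6, 0.4, 0.2, 0.0
```

With power=1 the decay is linear, and larger powers make the value drop faster at the start of the transition.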

Schedules can be combined with other transforms as follows.

schedule_fn = optax.polynomial_schedule(
    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale_by_schedule(schedule_fn))

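Schedules can also be passed directly as the learning_rate argument of an optimizer. In that case the schedule should return positive values, because optimizers such as optax.adam already negate the scaled updates internally:

lr_schedule_fn = optax.polynomial_schedule(
    init_value=learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = optax.adam(learning_rate=lr_schedule_fn)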

Applying updates (update.py)#

After transforming an update using a GradientTransformation or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using tree_map.

For convenience, we expose an apply_updates function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. tree_map(lambda p, u: p + u, params, updates).

updates, state = tx.update(grads, state, params)  # transform & update stats.
new_params = optax.apply_updates(params, updates)  # update the parameters.
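The additive rule described above can be written out with JAX's tree utilities alone. The sketch below applies it to a small dictionary of parameters; the names and values are made up for the illustration.

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
updates = {'w': jnp.array([-0.1, 0.1]), 'b': jnp.array(-0.5)}

# Add each update to the matching parameter, leaf by leaf.
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(new_params)
```

Because the rule only adds, any sign flip needed for descent has to be part of the updates themselves.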

Note that separating gradient transformations from the parameter update is critical to support composing a sequence of transformations (e.g. chain), as well as combining multiple updates to the same parameters (e.g. in multi-task settings where different tasks need different sets of gradient transformations).

Losses (loss.py)#

Optax provides a number of standard losses used in deep learning, such as l2_loss, softmax_cross_entropy, cosine_distance, etc.

predictions = net(TRAINING_DATA, params)
huber = optax.huber_loss(predictions, LABELS)  # Named `huber` to avoid shadowing the `loss` function defined earlier.

The losses accept batches as inputs, but they perform no reduction across the batch dimension(s). This is trivial to do in JAX, for example:

avg_loss = jnp.mean(optax.huber_loss(predictions, LABELS))
sum_loss = jnp.sum(optax.huber_loss(predictions, LABELS))

Second Order (second_order.py)#

Computing the Hessian or Fisher information matrices for neural networks is typically intractable due to the quadratic memory requirements. Solving for the diagonals of these matrices is often a better solution. The library offers functions for computing these diagonals with sub-quadratic memory requirements.

Stochastic gradient estimators (stochastic_gradient_estimators.py)#

Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution’s parameters.

Unbiased estimators, such as the score function estimator (REINFORCE), pathwise estimator (reparameterization trick) or measure valued estimator, are implemented: score_function_jacobians, pathwise_jacobians and measure_valued_jacobians. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.

Stochastic gradient estimators can be combined with common control variates for variance reduction via control_variates_jacobians. For provided control variates see control_delta_method and moving_avg_baseline.

The result of a gradient estimator or control_variates_jacobians contains the Jacobians of the function with respect to the samples from the input distribution. These can then be used to update distributional parameters or to assess gradient variance.

Example of how to use the pathwise_jacobians estimator:

mean, log_scale, rng, num_samples = 0., 1., jax.random.PRNGKey(0), 100
dist_params = [mean, log_scale]
function = lambda x: jnp.sum(x)
jacobians = optax.monte_carlo.pathwise_jacobians(
      function, dist_params,
      optax.multi_normal, rng, num_samples)

mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
grads = [mean_grads, log_scale_grads]
optim = optax.adam(1e-3)
optim_state = optim.init(grads)
optim_update, optim_state = optim.update(grads, optim_state)
updated_dist_params = optax.apply_updates(dist_params, optim_update)

where optim is an Optax optimizer.
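To show the idea behind the pathwise estimator without relying on the helpers above, here is a sketch in plain JAX for a one-dimensional Gaussian. It is not the Optax implementation: it returns gradients already averaged over samples instead of per-sample Jacobians, and the function f, the sample count and the parameter values are chosen only for illustration. Samples are written as x = mean + exp(log_scale) * eps with eps drawn from a standard normal, so the expectation can be differentiated through the samples.

```python
import jax
import jax.numpy as jnp


def f(x):
  return x ** 2


def monte_carlo_expectation(dist_params, eps):
  mean, log_scale = dist_params
  # Reparameterization: the randomness lives in eps, not in the parameters.
  x = mean + jnp.exp(log_scale) * eps
  return jnp.mean(f(x))


eps = jax.random.normal(jax.random.PRNGKey(0), (10_000,))
dist_params = (jnp.array(1.0), jnp.array(0.0))  # mean=1, scale=exp(0)=1

mean_grad, log_scale_grad = jax.grad(monte_carlo_expectation)(dist_params, eps)
print(mean_grad, log_scale_grad)
```

For this f the expectation is mean**2 + scale**2, so the exact gradients are 2 * mean and 2 * scale**2 (with respect to log_scale), both equal to 2 here. The estimates should be close to these values up to Monte Carlo noise.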