🚀 Getting started#
Optax is a simple optimization library for JAX. The main object is the GradientTransformation, which can be chained with other transformations to obtain the final update operation and the optimizer state. Optax also contains some simple loss functions and utilities to help you write the full optimization steps. This notebook walks you through a few examples of how to use Optax.
Example: Fitting a Linear Model#
Begin by importing the necessary packages:
import jax.numpy as jnp
import jax
import optax
import functools
In this example, we begin by setting up a simple linear model and a loss function. You can use any other library, such as Haiku or Flax, to construct your networks. Here, we keep it simple and write the model ourselves. The loss function (L2 loss) comes from Optax’s losses via l2_loss.
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  return jnp.dot(params, x)

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = jnp.mean(optax.l2_loss(y_pred, y))
  return loss
Here we generate data under a known linear model (with target_params=0.5):
key = jax.random.PRNGKey(42)
target_params = 0.5
# Generate some data.
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)
Basic usage of Optax#
Optax contains implementations of many popular optimizers that can be used very simply. For example, the gradient transform for the Adam optimizer is available at optax.adam. For now, let’s start by calling optax.adam to create the GradientTransformation object for the Adam optimizer. We then initialize the optimizer state by passing the params of the network to its init function.
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)
# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)
Next we write the update loop. The GradientTransformation object contains an update function that takes in the gradients and the current optimizer state and returns the updates that need to be applied to the parameters, together with the new optimizer state: updates, new_opt_state = optimizer.update(grads, opt_state).
Optax comes with a few simple update rules that apply the updates from the gradient transforms to the current parameters to return new ones: new_params = optax.apply_updates(params, updates).
# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
    'Optimization should retrieve the target params used to generate the data.'
Custom optimizers#
Optax makes it easy to create custom optimizers by chaining gradient transforms. For example, this creates an optimizer based on Adam. Note that the overall scaling is the negative learning rate (here, the scheduled learning rate followed by scale(-1.0)), which is an important detail since apply_updates is additive.
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip the gradient by the global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)
# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0]) # Recall target_params=0.5.
opt_state = gradient_transform.init(params)
# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = gradient_transform.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
    'Optimization should retrieve the target params used to generate the data.'
Advanced usage of Optax#
Modifying hyperparameters of optimizers in a schedule.#
In some scenarios, changing the hyperparameters (other than the learning rate) of an optimizer can be useful to ensure training reliability. We can do this easily by using inject_hyperparams. For example, this piece of code decays the max_norm of the clip_by_global_norm gradient transform as training progresses:
decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(
    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))

opt_state = decaying_global_norm_tx.init(None)
assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'

for _ in range(100):
  _, opt_state = decaying_global_norm_tx.update(None, opt_state)

assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'
Example: Fitting a MLP#
Let’s use Optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.
We will begin by creating a dataset that consists of batches of random 8-bit integers (represented using their binary representation), with each value labelled as “even” or “odd” using one-hot encoding. Since the labels are built with jax.nn.one_hot(value % 2, 2), [1, 0] means even and [0, 1] means odd.
import optax
import jax.numpy as jnp
import jax
import numpy as np
BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))
TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)
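To see what a single data point looks like, the sketch below applies the same two calls to two hand-picked values (5 and 6, chosen only for illustration).

```python
import jax
import numpy as np

raw = np.array([[5], [6]])  # 5 is odd, 6 is even.

# Each integer becomes its 8 bits, most significant bit first.
bits = np.unpackbits(raw.astype(np.uint8), axis=-1)

# one_hot places the 1 at index (value % 2): index 0 for even, index 1 for odd.
labels = jax.nn.one_hot(raw % 2, 2).reshape(-1, 2)
```

The last bit of each row in bits is the parity, so the task is learnable from a single input feature.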
We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.
There are a number of libraries that provide common building blocks for parametrized functions (such as flax and haiku). For this case though, we shall implement our function from scratch.
Our function will be a 1-layer MLP (multi-layer perceptron) with a single hidden layer, and a single output layer. We initialize all parameters using a standard Gaussian \(\mathcal{N}(0,1)\) distribution.
initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}

def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x

def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)
  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)
  return loss_value.mean()
We will use optax.adam to compute the parameter updates from their gradients on each optimizer step.
Note that since Optax optimizers are implemented using pure functions, we will need to also keep track of the optimizer state. For the Adam optimizer, this state will contain the running estimates of the first and second moments of the gradients, along with a step counter.
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
step 0, loss: 11.816939353942871
step 100, loss: 1.0228036642074585
step 200, loss: 0.22387900948524475
step 300, loss: 0.0383874736726284
step 400, loss: 0.03373744711279869
step 500, loss: 0.004626731853932142
step 600, loss: 0.0013608938315883279
step 700, loss: 0.0043336618691682816
step 800, loss: 0.005580856930464506
step 900, loss: 0.00279889814555645
We see that our loss appears to have converged, which should indicate that we have successfully found better parameters for our network.
Weight Decay, Schedules and Clipping#
Many research models make use of techniques such as learning rate scheduling and gradient clipping. These may be achieved by chaining together gradient transformations such as optax.adam and optax.clip.
In the following, we will use Adam with weight decay (optax.adamw), a cosine learning rate schedule (with warmup) and also gradient clipping.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1.0,
    warmup_steps=50,
    decay_steps=1_000,
    end_value=0.0,
)

optimizer = optax.chain(
    optax.clip(1.0),
    optax.adamw(learning_rate=schedule),
)

params = fit(initial_params, optimizer)
params = fit(initial_params, optimizer)
step 0, loss: 11.816939353942871
step 100, loss: 6.888598704790638e-07
step 200, loss: 6.104054932620784e-07
step 300, loss: 2.5261655878239253e-08
step 400, loss: 1.0381097581557697e-06
step 500, loss: 1.0905824687768018e-08
step 600, loss: 4.3513332721083556e-14
step 700, loss: 3.4981277874379657e-09
step 800, loss: 1.0796670224522131e-08
step 900, loss: 1.4792466807023175e-08
Components#
We refer to the docs for a detailed list of available Optax components. Here, we highlight the main categories of building blocks provided by Optax.
Gradient Transformations (transform.py)#
One of the key building blocks of Optax is a GradientTransformation. Each transformation is defined by two functions:
state = init(params)
grads, state = update(grads, state, params=None)
The init function initializes a (possibly empty) set of statistics (aka state) and the update function transforms a candidate gradient given some statistics, and (optionally) the current value of the parameters.
For example:
tx = optax.scale_by_rms()
state = tx.init(params) # init stats
grads = jax.grad(loss)(params, TRAINING_DATA, LABELS)
updates, state = tx.update(grads, state, params) # transform & update stats.
Composing Gradient Transformations (combine.py)#
The fact that transformations take candidate gradients as input and return processed gradients as output (in contrast to returning the updated parameters) is critical: it allows arbitrary transformations to be combined into a custom optimiser / gradient processor, and it also allows transformations for different gradients that operate on a shared set of variables to be combined.
For instance, chain combines them sequentially, and returns a new GradientTransformation that applies several transformations in sequence.
For example:
max_norm = 100.
learning_rate = 1e-3

my_optimiser = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale(-learning_rate))
Wrapping Gradient Transformations (wrappers.py)#
Optax also provides several wrappers that take a GradientTransformation as input and return a new GradientTransformation that modifies the behaviour of the inner transformation in a specific way.
For instance, the flatten wrapper flattens gradients into a single large vector before applying the inner GradientTransformation. The transformed updates are then unflattened before being returned to the user. This can be used to reduce the overhead of performing many calculations on lots of small variables, at the cost of increasing memory usage.
For example:
my_optimiser = optax.flatten(optax.adam(learning_rate))
Other examples of wrappers include accumulating gradients over multiple steps or applying the inner transformation only to specific parameters or at specific steps.
Schedules (schedule.py)#
Many popular transformations use time-dependent components, e.g. to anneal some hyper-parameter (e.g. the learning rate). For this purpose, Optax provides schedules that can be used to decay scalars as a function of a step count.
For example, you may use a polynomial_schedule (with power=1) to decay a hyper-parameter linearly over a number of steps:
schedule_fn = optax.polynomial_schedule(
    init_value=1., end_value=0., power=1, transition_steps=5)

for step_count in range(6):
  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]
1.0
0.8
0.6
0.39999998
0.19999999
0.0
Schedules can be combined with other transforms as follows.
schedule_fn = optax.polynomial_schedule(
    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale_by_schedule(schedule_fn))
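Schedules can also be used in place of the learning_rate argument of a GradientTransformation. In that case the schedule should return positive values, because the optimiser applies the negative sign itself (unlike the negative schedule_fn passed to scale_by_schedule above):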
optimiser = optax.adam(learning_rate=schedule_fn)
Popular optimisers (alias.py)#
In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam, AdamW, etc.). These are all still instances of a GradientTransformation, and can therefore be further combined with any of the individual building blocks.
For example:
def adamw(learning_rate, b1, b2, eps, weight_decay):
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
      optax.add_decayed_weights(weight_decay),  # Adds weight_decay * params to the updates.
      optax.scale(-learning_rate))
Applying updates (update.py)#
After transforming an update using a GradientTransformation or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using tree_map.
For convenience, we expose an apply_updates function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. tree_map(lambda p, u: p + u, params, updates).
updates, state = tx.update(grads, state, params) # transform & update stats.
new_params = optax.apply_updates(params, updates) # update the parameters.
Note that separating gradient transformations from the parameter update is critical to support composing a sequence of transformations (e.g. chain), as well as combining multiple updates to the same parameters (e.g. in multi-task settings where different tasks need different sets of gradient transformations).
Losses (loss.py)#
Optax provides a number of standard losses used in deep learning, such as l2_loss, softmax_cross_entropy, cosine_distance, etc.
predictions = net(TRAINING_DATA, params)
# Named loss_values so that the loss function defined earlier is not shadowed.
loss_values = optax.huber_loss(predictions, LABELS)
The losses accept batches as inputs; however, they perform no reduction across the batch dimension(s). This is trivial to do in JAX, for example:
avg_loss = jnp.mean(optax.huber_loss(predictions, LABELS))
sum_loss = jnp.sum(optax.huber_loss(predictions, LABELS))
Second Order (second_order.py)#
Computing the Hessian or Fisher information matrices for neural networks is typically intractable due to the quadratic memory requirements. Solving for the diagonals of these matrices is often a better solution. The library offers functions for computing these diagonals with sub-quadratic memory requirements.
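To illustrate the idea of obtaining a Hessian diagonal without materialising the full matrix, here is a plain-JAX sketch. It is not Optax's implementation: hessian_diagonal is a hypothetical helper that computes one Hessian-vector product per coordinate and keeps only the matching entry, so memory stays linear in the number of parameters (at the cost of one pass per parameter).

```python
import jax
import jax.numpy as jnp

def hessian_diagonal(f, x):
  """Hypothetical helper: diagonal of the Hessian of f at a 1-D array x."""
  grad_f = jax.grad(f)

  def diag_entry(i):
    e_i = jnp.zeros_like(x).at[i].set(1.0)      # i-th basis vector.
    _, hvp = jax.jvp(grad_f, (x,), (e_i,))      # Hessian-vector product H @ e_i.
    return hvp[i]                               # Keep only H[i, i].

  # lax.map runs the entries sequentially, so only one vector is live at a time.
  return jax.lax.map(diag_entry, jnp.arange(x.shape[0]))

f = lambda x: jnp.sum(x ** 3) + x[0] * x[1]  # Toy function with an off-diagonal term.
x = jnp.array([1.0, 2.0, 3.0])

diag = hessian_diagonal(f, x)
dense_diag = jnp.diag(jax.hessian(f)(x))  # Reference using the full Hessian.
```

For this toy function the diagonal is 6 * x, while the cross term x[0] * x[1] only contributes to off-diagonal entries.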
Stochastic gradient estimators (stochastic_gradient_estimators.py)#
Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution’s parameters.
Unbiased estimators, such as the score function estimator (REINFORCE), pathwise estimator (reparameterization trick) or measure valued estimator, are implemented: score_function_jacobians, pathwise_jacobians and measure_valued_jacobians. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.
Stochastic gradient estimators can be combined with common control variates for variance reduction via control_variates_jacobians. For provided control variates see control_delta_method and moving_avg_baseline.
The result of a gradient estimator or control_variates_jacobians contains the Jacobians of the function with respect to the samples from the input distribution. These can then be used to update distributional parameters or to assess gradient variance.
Example of how to use the pathwise_jacobians estimator:
mean, log_scale, rng, num_samples = 0., 1., jax.random.PRNGKey(0), 100
dist_params = [mean, log_scale]
function = lambda x: jnp.sum(x)
jacobians = optax.monte_carlo.pathwise_jacobians(
    function, dist_params,
    optax.multi_normal, rng, num_samples)
mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
grads = [mean_grads, log_scale_grads]
optim = optax.adam(1e-3)
optim_state = optim.init(grads)
optim_update, optim_state = optim.update(grads, optim_state)
updated_dist_params = optax.apply_updates(dist_params, optim_update)
where optim is an Optax optimizer.
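For intuition about what a pathwise estimator computes, the following plain-JAX sketch implements the reparameterization trick for a one-dimensional Gaussian. It does not use the Optax estimators; pathwise_gradients is a hypothetical helper, and the test function x ** 2 is chosen because the exact gradients are known.

```python
import jax
import jax.numpy as jnp

def pathwise_gradients(function, mean, log_scale, rng, num_samples):
  """Hypothetical helper: Monte Carlo gradient of E[function(x)],
  x ~ N(mean, exp(log_scale) ** 2), with respect to (mean, log_scale)."""
  noise = jax.random.normal(rng, (num_samples,))

  def sample_value(dist_params, eps):
    mu, log_sigma = dist_params
    # Reparameterise: the sample is a differentiable function of the parameters.
    return function(mu + jnp.exp(log_sigma) * eps)

  dist_params = (jnp.asarray(mean, dtype=jnp.float32),
                 jnp.asarray(log_scale, dtype=jnp.float32))
  # One gradient per noise sample, then average over the samples.
  per_sample = jax.vmap(jax.grad(sample_value), in_axes=(None, 0))(dist_params, noise)
  return jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), per_sample)

mean_grad, log_scale_grad = pathwise_gradients(
    lambda x: x ** 2, mean=1.0, log_scale=0.0,
    rng=jax.random.PRNGKey(0), num_samples=100_000)
```

Since E[x ** 2] = mean ** 2 + exp(2 * log_scale), both exact gradients equal 2 at these parameter values, and the estimates should land close to that up to Monte Carlo noise.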