Lookahead Optimizer on MNIST

This notebook trains a simple convolutional neural network (CNN) for handwritten digit recognition (MNIST dataset) using the Lookahead optimizer.

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import Sequence

# Report the platform of the first device JAX sees (e.g. CPU or GPU).
platform_name = jax.devices()[0].platform.upper()
print(f"JAX running on {platform_name}")
# Training hyperparameters. The `# @markdown` / `# @param` comments appear to
# be Colab form annotations that expose these values as editable fields, so
# they are kept on the lines they annotate.
# @markdown The learning rate for the fast optimizer:
FAST_LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown The learning rate for the slow optimizer:
SLOW_LEARNING_RATE = 0.5 # @param{type:"number"}
# @markdown Number of fast optimizer steps to take before synchronizing parameters:
SYNC_PERIOD = 5 # @param{type:"integer"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 128 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 5 # @param{type:"integer"}

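MNIST is a dataset of 28x28 grayscale images (1 channel). We now load the dataset using tensorflow_datasets and scale pixel values from [0, 255] to [0, 1] (min-max normalization). We then shuffle the training set and split both the training and test sets into batches of size BATCH_SIZE, dropping any incomplete final batch.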

(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)
NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape


def min_max_rgb(image, label):
  """Scales pixel values from [0, 255] to float32 values in [0, 1].

  Despite the name, this is applied to MNIST's single-channel images; the
  name is kept for compatibility with existing callers.

  Args:
    image: an image tensor with values in [0, 255].
    label: the example's label, passed through unchanged.

  Returns:
    Tuple `(scaled_image, label)` with `scaled_image` cast to float32.
  """
  return tf.cast(image, tf.float32) / 255., label


train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

# Reshuffle the training set every epoch; drop_remainder keeps every batch the
# same shape.
train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

# NOTE(review): drop_remainder=True also discards the final partial test
# batch, so a few test examples are excluded from the reported test metrics.
test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)
The data is ready! Next let’s define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple CNN.

class CNN(nn.Module):
  """A simple CNN model.

  Two (conv -> ReLU -> 2x2 average pool) stages, then a hidden dense layer
  with ReLU and a linear output layer with NUM_CLASSES units.

  Attributes:
    hidden_sizes: `hidden_sizes[0]` is the number of features of the second
      convolution and `hidden_sizes[1]` is the width of the hidden dense layer.
  """
  hidden_sizes: Sequence[int] = (1000, 1000)

  # Submodules are declared inline below, which Flax only allows inside a
  # method wrapped in `nn.compact`.
  @nn.compact
  def __call__(self, x):
    """Applies the network to a batch of images.

    Args:
      x: a batch of images with the batch on axis 0; the model is initialized
        with an input of shape `(1,) + IMG_SIZE`.

    Returns:
      Unnormalized logits of shape `(batch, NUM_CLASSES)`.
    """
    # NOTE(review): the first conv uses IMG_SIZE[0] (the image height) as its
    # number of output features; kept as-is, but confirm this is intended.
    x = nn.Conv(features=IMG_SIZE[0], kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=self.hidden_sizes[0], kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # Flatten everything but the batch axis.
    x = nn.Dense(features=self.hidden_sizes[1])(x)
    x = nn.relu(x)
    x = nn.Dense(features=NUM_CLASSES)(x)
    return x
net = CNN()

def predict(params, inputs):
  """Returns the logits that `net` produces for `inputs` under `params`."""
  variables = {'params': params}
  return net.apply(variables, inputs)

def loss_accuracy(params, data):
  """Computes loss and accuracy over a mini-batch.

  Args:
    params: parameters of the model.
    data: tuple of (inputs, labels), with integer class labels.

  Returns:
    Tuple `(loss, aux)`. `loss` is the softmax cross-entropy averaged over
    the mini-batch; it is reduced to a scalar because `jax.value_and_grad`
    requires a scalar output. `aux` is a dict whose "accuracy" entry holds
    the fraction of correctly classified examples.
  """
  inputs, labels = data
  logits = predict(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
  return loss, {"accuracy": accuracy}

def update_model(state, grads):
  """Returns the state obtained by applying `grads` via `state.apply_gradients`.

  NOTE(review): not called anywhere in the visible code; `state` is presumably
  a Flax `TrainState`-like object — confirm before relying on it.
  """
  new_state = state.apply_gradients(grads=grads)
  return new_state

Next we need to initialize the CNN parameters and the solver state. We also define a convenience function, dataset_stats, that we call once per epoch to collect the loss and accuracy of our model on the test set. We will be using the Lookahead optimizer, which maintains two copies of the parameters: a fast copy and a slow copy. To initialize them, we create a synchronized pair from the initial model parameters.

# Adam updates the fast parameters; Lookahead wraps it and synchronizes the
# slow parameters every SYNC_PERIOD fast steps, using SLOW_LEARNING_RATE as
# the slow step size.
fast_solver = optax.adam(FAST_LEARNING_RATE)
solver = optax.lookahead(
    fast_solver,
    sync_period=SYNC_PERIOD,
    slow_step_size=SLOW_LEARNING_RATE,
)

# Initialize the model from a single all-ones example of the image shape.
rng = jax.random.PRNGKey(0)
dummy_data = jnp.ones((1, *IMG_SIZE), dtype=jnp.float32)
variables = net.init({"params": rng}, dummy_data)

# Start the fast and slow copies identical, then build the optimizer state
# for the pair.
params = optax.LookaheadParams.init_synced(variables["params"])
solver_state = solver.init(params)

def dataset_stats(params, data_loader):
  """Computes loss and accuracy over the dataset `data_loader`.

  Args:
    params: model parameters to evaluate (e.g. the slow Lookahead params).
    data_loader: a batched dataset exposing `as_numpy_iterator()` that yields
      `(inputs, labels)` batches, e.g. `test_loader_batched`.

  Returns:
    Dict with "loss" and "accuracy", each the unweighted mean of the per-batch
    values. Both are NaN if `data_loader` yields no batches.
  """
  all_accuracy = []
  all_loss = []
  for batch in data_loader.as_numpy_iterator():
    batch_loss, batch_aux = loss_accuracy(params, batch)
    # Record every batch; without this the means below are taken over
    # empty lists and come out as NaN.
    all_loss.append(batch_loss)
    all_accuracy.append(batch_aux["accuracy"])
  return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}

Finally, we do the actual training. The next cell trains the model for N_EPOCHS epochs. Within each epoch we iterate over the batched loader train_loader_batched, and once per epoch we also compute the test-set accuracy and loss.

# Per-epoch training metric histories.
train_accuracy, train_losses = [], []

# Evaluate the untrained slow parameters once so that the test history starts
# with the loss/accuracy at initialization.
test_stats = dataset_stats(params.slow, test_loader_batched)
test_accuracy, test_losses = [test_stats["accuracy"]], [test_stats["loss"]]

@jax.jit
def train_step(params, solver_state, batch):
  """Performs one Lookahead optimization step on a mini-batch.

  Gradients are taken with respect to the fast parameters only; the Lookahead
  solver decides when to synchronize the slow parameters. The step is jitted.
  Every training batch has the same shape (drop_remainder=True), so the step
  compiles once.

  Args:
    params: an `optax.LookaheadParams` holding the fast and slow parameters.
    solver_state: the current state of `solver`.
    batch: tuple of (inputs, labels).

  Returns:
    Tuple `(params, solver_state, loss, aux)`: the updated parameters and
    optimizer state, the mini-batch loss, and the metrics dict from
    `loss_accuracy`.
  """
  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(
      params.fast, batch
  )
  # Lookahead consumes gradients of the fast params plus the full params pair
  # and returns updates for both copies.
  updates, solver_state = solver.update(grad, solver_state, params)
  params = optax.apply_updates(params, updates)
  return params, solver_state, loss, aux

for epoch in range(N_EPOCHS):
  train_accuracy_epoch = []
  train_losses_epoch = []

  for train_batch in train_loader_batched.as_numpy_iterator():
    params, solver_state, train_loss, train_aux = train_step(
        params, solver_state, train_batch
    )
    train_accuracy_epoch.append(train_aux["accuracy"])
    train_losses_epoch.append(train_loss)

  # Per-epoch training metrics are the mean over this epoch's mini-batches.
  train_accuracy.append(np.mean(train_accuracy_epoch))
  train_losses.append(np.mean(train_losses_epoch))

  # Validation is done on the slow lookahead parameters.
  test_stats = dataset_stats(params.slow, test_loader_batched)
  test_accuracy.append(test_stats["accuracy"])
  test_losses.append(test_stats["loss"])
# Report how much test accuracy improved from initialization to the end of
# training. print() is used so the message also appears when this runs as a
# plain script rather than as the last expression of a notebook cell.
print(f"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}")
# Example output from a previous notebook run:
# 'Improved accuracy on test DS from 0.06760817021131516 to 0.9924879670143127'