Lookahead Optimizer on MNIST


This notebook trains a simple convolutional neural network (CNN) for handwritten digit recognition on the MNIST dataset using the Lookahead optimizer.
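Before diving into the full pipeline, here is a minimal sketch of the update rule Lookahead implements (this sketch is not part of the original notebook, and the names toy_lookahead and fast_step are hypothetical): an inner "fast" optimizer takes sync_period steps, then the "slow" parameters move a fraction slow_step_size of the way toward the fast parameters, and the fast parameters restart from the new slow ones.

import jax.numpy as jnp

def toy_lookahead(fast_step, sync_period=5, slow_step_size=0.5, n_rounds=3):
  # Toy illustration of the Lookahead rule on a single scalar parameter.
  slow = jnp.array(5.0)         # slow (outer) parameters
  for _ in range(n_rounds):
    fast = slow                 # fast parameters restart from the slow ones
    for _ in range(sync_period):
      fast = fast_step(fast)    # sync_period inner optimizer steps
    slow = slow + slow_step_size * (fast - slow)  # interpolate toward fast
  return slow

# Plain gradient descent on f(x) = x**2 plays the role of the fast optimizer.
print(toy_lookahead(lambda x: x - 0.1 * (2 * x)))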

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# @markdown The learning rate for the fast optimizer:
FAST_LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown The learning rate for the slow optimizer:
SLOW_LEARNING_RATE = 0.1 # @param{type:"number"}
# @markdown Number of fast optimizer steps to take before synchronizing parameters:
SYNC_PERIOD = 5 # @param{type:"integer"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 256 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 1 # @param{type:"integer"}

MNIST is a dataset of 28x28 grayscale images (a single channel). We now load the dataset using tensorflow_datasets, apply min-max normalization to the images, shuffle the training set, and create batches of size BATCH_SIZE.

(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)
NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape

min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)
Dataset mnist downloaded and prepared to /home/docs/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
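As a quick sanity check (a hypothetical addition, not part of the original notebook), we can pull a single batch from the training pipeline and confirm the expected shapes: float32 images of shape (BATCH_SIZE, 28, 28, 1) and integer labels of shape (BATCH_SIZE,).

# Hypothetical sanity check of the input pipeline (not in the original notebook).
images, labels = next(train_loader_batched.as_numpy_iterator())
print(images.shape, images.dtype)  # (256, 28, 28, 1) float32
print(labels.shape, labels.dtype)  # (256,) int64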

The data is ready! Next let’s define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple CNN.

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
net = CNN()
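
# Optional sanity check (a sketch, not part of the original notebook; the
# underscore-prefixed names are hypothetical): run the untrained model on a
# dummy batch and confirm the logits have shape (batch, NUM_CLASSES) == (1, 10).
_dummy = jnp.ones((1, 28, 28, 1), dtype=jnp.float32)
_variables = net.init(jax.random.PRNGKey(0), _dummy)
print(net.apply(_variables, _dummy).shape)  # (1, 10)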

@jax.jit
def predict(params, inputs):
  return net.apply({'params': params}, inputs)


@jax.jit
def loss_accuracy(params, data):
  """Computes loss and accuracy over a mini-batch.

  Args:
    params: parameters of the model.
    bn_params: state of the model.
    data: tuple of (inputs, labels).
    is_training: if true, uses train mode, otherwise uses eval mode.

  Returns:
    loss: float
  """
  inputs, labels = data
  logits = predict(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
  return loss, {"accuracy": accuracy}

Next we initialize the CNN parameters and solver state. We also define a convenience function dataset_stats that we’ll call once per epoch to collect the loss and accuracy of our solver over the test set. We use the Lookahead optimizer, whose wrapper keeps a pair of slow and fast parameters; to initialize them, we create a synchronized pair from the initial model parameters.

fast_solver = optax.adam(FAST_LEARNING_RATE)
solver = optax.lookahead(fast_solver, SYNC_PERIOD, SLOW_LEARNING_RATE)
rng = jax.random.PRNGKey(0)
dummy_data = jnp.ones((1,) + IMG_SIZE, dtype=jnp.float32)

params = net.init({"params": rng}, dummy_data)["params"]

# Initializes the lookahead optimizer with the initial model parameters.
params = optax.LookaheadParams.init_synced(params)
solver_state = solver.init(params)
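
# Quick check (a sketch, not part of the original notebook): after `init_synced`
# the fast and slow parameter trees should be identical.
all_synced = jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.array_equal, params.fast, params.slow)
)
print(all_synced)  # True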

def dataset_stats(params, data_loader):
  """Computes loss and accuracy over the dataset `data_loader`."""
  all_accuracy = []
  all_loss = []
  for batch in data_loader.as_numpy_iterator():
    batch_loss, batch_aux = loss_accuracy(params, batch)
    all_loss.append(batch_loss)
    all_accuracy.append(batch_aux["accuracy"])
  return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}

Finally, we do the actual training. The next cell trains the model for N_EPOCHS epochs: within each epoch we iterate over the batched loader train_loader_batched, and once per epoch we also compute the loss and accuracy on the test set.

train_accuracy = []
train_losses = []

# Computes test set accuracy at initialization.
test_stats = dataset_stats(params.slow, test_loader_batched)
test_accuracy = [test_stats["accuracy"]]
test_losses = [test_stats["loss"]]


@jax.jit
def train_step(params, solver_state, batch):
  # Performs a single optimization step.
  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(
      params.fast, batch
  )
  updates, solver_state = solver.update(grad, solver_state, params)
  params = optax.apply_updates(params, updates)
  return params, solver_state, loss, aux


for epoch in range(N_EPOCHS):
  train_accuracy_epoch = []
  train_losses_epoch = []

  for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):
    params, solver_state, train_loss, train_aux = train_step(
        params, solver_state, train_batch
    )
    train_accuracy_epoch.append(train_aux["accuracy"])
    train_losses_epoch.append(train_loss)
    if step % 20 == 0:
      print(
          f"step {step}, train loss: {train_loss:.2e}, train accuracy:"
          f" {train_aux['accuracy']:.2f}"
      )

  # Validation is done on the slow lookahead parameters.
  test_stats = dataset_stats(params.slow, test_loader_batched)
  test_accuracy.append(test_stats["accuracy"])
  test_losses.append(test_stats["loss"])
  train_accuracy.append(np.mean(train_accuracy_epoch))
  train_losses.append(np.mean(train_losses_epoch))
step 0, train loss: 2.30e+00, train accuracy: 0.11
step 20, train loss: 1.74e+00, train accuracy: 0.77
step 40, train loss: 1.06e+00, train accuracy: 0.82
step 60, train loss: 7.06e-01, train accuracy: 0.86
step 80, train loss: 5.63e-01, train accuracy: 0.86
step 100, train loss: 4.33e-01, train accuracy: 0.88
step 120, train loss: 3.94e-01, train accuracy: 0.90
step 140, train loss: 3.24e-01, train accuracy: 0.91
step 160, train loss: 3.36e-01, train accuracy: 0.91
step 180, train loss: 3.23e-01, train accuracy: 0.91
step 200, train loss: 3.34e-01, train accuracy: 0.90
step 220, train loss: 2.18e-01, train accuracy: 0.95
f"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}"
'Improved accuracy on test DS from 0.1041666641831398 to 0.930588960647583'
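As a possible follow-up (not part of the original notebook; assumes matplotlib is installed), the per-epoch metrics collected above can be plotted to compare train and test accuracy:

# Hypothetical follow-up plot (not in the original notebook).
import matplotlib.pyplot as plt

plt.plot(np.arange(len(test_accuracy)), test_accuracy, label="test accuracy")
plt.plot(np.arange(1, len(train_accuracy) + 1), train_accuracy,
         label="train accuracy (epoch mean)")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()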