MLP MNIST
This notebook trains a simple Multilayer Perceptron (MLP) classifier for hand-written digit recognition (MNIST dataset).
from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import Sequence
2024-04-29 08:13:18.000451: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/docs/checkouts/readthedocs.org/user_builds/optax/envs/latest/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
# Training hyperparameters. The `# @markdown` / `# @param` comments below
# appear to be Colab form annotations (editable form fields); keep their exact
# format. LEARNING_RATE is passed to optax.adam, BATCH_SIZE sets the batch size
# of both data loaders, and N_EPOCHS is the number of passes over the train set.
# @markdown The learning rate for the optimizer:
LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 128 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 1 # @param{type:"integer"}
MNIST is a dataset of 28x28 grayscale images (1 channel). We now load the dataset using `tensorflow_datasets`, apply min-max normalization to the images (scaling pixel values from [0, 255] to [0, 1]), shuffle the data in the train set and create batches of size `BATCH_SIZE`.
# Load the MNIST train/test splits as (image, label) pairs, plus the dataset
# metadata (`info`) used below to read the number of classes and image shape.
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)


def min_max_rgb(image, label):
    """Min-max normalizes an image: casts to float32 and divides by 255.

    This maps pixel values in [0, 255] to [0, 1]. MNIST images have a single
    channel despite the "rgb" in the name; the scaling does not depend on the
    number of channels. ``label`` is returned unchanged.
    """
    return tf.cast(image, tf.float32) / 255.0, label


train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

NUM_CLASSES = info.features["label"].num_classes
# Per-example image shape; used later to build the dummy input for net.init.
IMG_SIZE = info.features["image"].shape

# Reshuffle the training set on every pass, using a 10k-example buffer.
# drop_remainder=True presumably keeps every batch the same static shape, so
# the jitted functions are not retraced for a smaller final batch.
train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)
# NOTE: drop_remainder=True also discards the final partial test batch, so test
# metrics cover len(test) - len(test) % BATCH_SIZE examples, not the full split.
test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)
The data is ready! Next let’s define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple MLP.
class MLP(nn.Module):
    """A simple multilayer perceptron model for image classification.

    Attributes:
        hidden_sizes: width of each hidden ``Dense`` layer, in order. Every
            hidden layer is followed by a ReLU. Any number of hidden layers is
            supported; an empty sequence yields a linear classifier.
    """

    hidden_sizes: Sequence[int] = (1000, 1000)

    @nn.compact
    def __call__(self, x):
        """Returns unnormalized class logits of shape ``(batch, NUM_CLASSES)``.

        The leading axis of ``x`` is treated as the batch axis; all remaining
        axes are flattened into a single feature vector per example.
        """
        # Flattens images in the batch.
        x = x.reshape((x.shape[0], -1))
        # One Dense + ReLU per configured hidden width. The layers are created
        # in the same order as a hand-unrolled version would create them
        # (Dense_0, Dense_1, ...), so the parameter tree for the default
        # configuration is unchanged.
        for width in self.hidden_sizes:
            x = nn.Dense(features=width)(x)
            x = nn.relu(x)
        # Output layer: raw logits, no activation.
        return nn.Dense(features=NUM_CLASSES)(x)
# Single module instance. The parameters are not stored on it; they are
# created by net.init and passed explicitly to net.apply.
net = MLP()


@jax.jit
def predict(params, inputs):
    """Runs the MLP forward pass and returns its logits for ``inputs``.

    ``params`` is the "params" collection returned by ``net.init``.
    """
    variables = {"params": params}
    logits = net.apply(variables, inputs)
    return logits
@jax.jit
def loss_accuracy(params, data):
    """Computes loss and accuracy over a mini-batch.

    Args:
      params: parameters of the model (the "params" collection that
        ``predict`` forwards to ``net.apply``).
      data: tuple of ``(inputs, labels)``, where ``labels`` are integer class
        indices.

    Returns:
      A ``(loss, aux)`` pair, matching the ``has_aux=True`` convention of
      ``jax.value_and_grad``:
        loss: scalar mean softmax cross-entropy over the batch.
        aux: dict with key ``"accuracy"``, the fraction of examples whose
          arg-max logit equals the label.
    """
    inputs, labels = data
    logits = predict(params, inputs)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss, {"accuracy": accuracy}
@jax.jit
def update_model(state, grads):
    """Returns the result of ``state.apply_gradients(grads=grads)``.

    ``state`` is presumably a Flax ``TrainState``-like object; anything that is
    a JAX pytree and exposes ``apply_gradients(grads=...)`` works. Verify before
    relying on this.
    NOTE(review): the training loop in this notebook does not call this helper;
    ``train_step`` applies updates with optax directly.
    """
    new_state = state.apply_gradients(grads=grads)
    return new_state
Next we need to initialize the network parameters and the solver state. We also define a convenience function, `dataset_stats`, that we’ll call once per epoch to collect the loss and accuracy of the model over the test set.
# Fixed PRNG key so parameter initialization is reproducible.
rng = jax.random.PRNGKey(0)

# Adam optimizer with the learning rate chosen above.
solver = optax.adam(LEARNING_RATE)

# Build the parameters by tracing the network on a single all-ones example
# with the dataset's image shape, keeping only the "params" collection.
dummy_data = jnp.ones((1, *IMG_SIZE), dtype=jnp.float32)
params = net.init({"params": rng}, dummy_data)["params"]

# Optimizer state for `params`, threaded through every update.
solver_state = solver.init(params)
def dataset_stats(params, data_loader):
    """Computes loss and accuracy over the dataset `data_loader`.

    Per-batch metrics from ``loss_accuracy`` are averaged with each batch
    weighted by its number of examples. The result is therefore the
    per-example mean even when the last batch is smaller (for example, a
    loader batched with ``drop_remainder=False``). With equally sized batches
    this equals the plain mean over batches.

    Args:
      params: model parameters, forwarded to ``loss_accuracy``.
      data_loader: a ``tf.data.Dataset`` yielding ``(inputs, labels)``
        batches.

    Returns:
      Dict with scalar entries ``"loss"`` and ``"accuracy"``.

    Raises:
      ValueError: if ``data_loader`` yields no batches, because averaging over
        zero examples is undefined.
    """
    all_loss, all_accuracy, batch_sizes = [], [], []
    for batch in data_loader.as_numpy_iterator():
        batch_loss, batch_aux = loss_accuracy(params, batch)
        all_loss.append(batch_loss)
        all_accuracy.append(batch_aux["accuracy"])
        # Number of examples in this batch (labels are the second element).
        batch_sizes.append(len(batch[1]))
    if not batch_sizes:
        raise ValueError("data_loader yielded no batches; cannot compute stats.")
    return {
        "loss": np.average(all_loss, weights=batch_sizes),
        "accuracy": np.average(all_accuracy, weights=batch_sizes),
    }
Finally, we do the actual training. The next cell trains the model for `N_EPOCHS` epochs. Within each epoch we iterate over the batched loader `train_loader_batched`, and once per epoch we also compute the test set accuracy and loss.
# Per-epoch mean training metrics; the training loop appends to these.
train_losses, train_accuracy = [], []

# Evaluate the untrained model first, so that index 0 of the test curves
# corresponds to the parameters at initialization.
test_stats = dataset_stats(params, test_loader_batched)
test_losses = [test_stats["loss"]]
test_accuracy = [test_stats["accuracy"]]
@jax.jit
def train_step(params, solver_state, batch):
    """Performs a single optimizer update on ``batch``.

    Returns:
      ``(params, solver_state, loss, aux)``: the updated parameters and
      optimizer state, followed by the mini-batch loss and the aux dict from
      ``loss_accuracy``. Both are evaluated at the parameters before the update.
    """
    grad_fn = jax.value_and_grad(loss_accuracy, has_aux=True)
    (loss, aux), grads = grad_fn(params, batch)
    updates, new_solver_state = solver.update(grads, solver_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_solver_state, loss, aux
for epoch in range(N_EPOCHS):
    # Per-step metrics for the current epoch, averaged at the end of the epoch.
    train_accuracy_epoch, train_losses_epoch = [], []
    batches = train_loader_batched.as_numpy_iterator()
    for step, train_batch in enumerate(batches):
        params, solver_state, train_loss, train_aux = train_step(
            params, solver_state, train_batch
        )
        batch_accuracy = train_aux["accuracy"]
        train_losses_epoch.append(train_loss)
        train_accuracy_epoch.append(batch_accuracy)
        # Log every 20th step, starting with step 0.
        if step % 20 == 0:
            print(
                f"step {step}, train loss: {train_loss:.2e}, train accuracy:"
                f" {batch_accuracy:.2f}"
            )

    # End of epoch: evaluate on the test set and record the epoch means.
    test_stats = dataset_stats(params, test_loader_batched)
    test_losses.append(test_stats["loss"])
    test_accuracy.append(test_stats["accuracy"])
    train_losses.append(np.mean(train_losses_epoch))
    train_accuracy.append(np.mean(train_accuracy_epoch))
2024-04-29 08:13:21.369685: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
step 0, train loss: 2.30e+00, train accuracy: 0.13
step 20, train loss: 2.73e-01, train accuracy: 0.94
step 40, train loss: 2.14e-01, train accuracy: 0.95
step 60, train loss: 2.31e-01, train accuracy: 0.94
step 80, train loss: 1.94e-01, train accuracy: 0.92
step 100, train loss: 1.76e-01, train accuracy: 0.94
step 120, train loss: 2.59e-01, train accuracy: 0.90
step 140, train loss: 2.27e-01, train accuracy: 0.91
step 160, train loss: 1.15e-01, train accuracy: 0.96
step 180, train loss: 1.72e-01, train accuracy: 0.95
step 200, train loss: 1.33e-01, train accuracy: 0.95
step 220, train loss: 1.39e-01, train accuracy: 0.96
step 240, train loss: 6.54e-02, train accuracy: 0.98
step 260, train loss: 1.06e-01, train accuracy: 0.97
step 280, train loss: 2.52e-01, train accuracy: 0.91
step 300, train loss: 1.69e-01, train accuracy: 0.95
step 320, train loss: 7.72e-02, train accuracy: 0.98
step 340, train loss: 1.35e-01, train accuracy: 0.98
step 360, train loss: 1.52e-01, train accuracy: 0.93
step 380, train loss: 1.19e-01, train accuracy: 0.95
step 400, train loss: 1.69e-01, train accuracy: 0.95
step 420, train loss: 3.04e-01, train accuracy: 0.94
step 440, train loss: 8.07e-02, train accuracy: 0.98
step 460, train loss: 1.23e-01, train accuracy: 0.98
2024-04-29 08:13:35.915038: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-29 08:13:42.508794: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
# Final cell value (displayed by the notebook): test accuracy at initialization
# (test_accuracy[0]) versus after the last epoch (test_accuracy[-1]).
# NOTE(review): the message says "Improved" unconditionally, even if accuracy
# did not actually increase.
f"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}"
'Improved accuracy on test DS from 0.14463141560554504 to 0.97265625'