MLP MNIST#
This notebook trains a simple Multilayer Perceptron (MLP) classifier for hand-written digit recognition (MNIST dataset).
To run the colab locally you need install grain via pip.
from typing import Sequence
import jax
import jax.numpy as jnp
import optax
import numpy as np
from flax import nnx
import grain.python as pygrain
from torchvision.datasets import MNIST
import torchvision.transforms as T
# @markdown The learning rate for the optimizer:
LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 128 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 1 # @param{type:"integer"}
# Number of classes (digits 0-9)
N_TARGETS = 10
# Input size (MNIST images are 28x28 pixels)
IMG_SIZE = 28 * 28
# Directory for storing the dataset
DATA_DIR = '/tmp/mnist_dataset'
Data Loading#
MNIST is a dataset of 28x28 images with 1 channel. We now load the dataset using torchvision, apply min-max normalization to the images, shuffle the data in the train set and create batches of size BATCH_SIZE using grain.
# Define the transformation
torch_transforms = T.Compose([
T.ToTensor(),
T.Lambda(lambda x: x.ravel()), # Flattens to (784,)
])
class Dataset:
def __init__(self, data_dir, train=True):
self.data_dir = data_dir
self.train = train
self.load_data()
def load_data(self):
self.dataset = MNIST(self.data_dir, download=True, train=self.train,
transform=torch_transforms)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
img, label = self.dataset[index]
return np.array(img, dtype=np.float32), label
Initialize the datasets
mnist_dataset_train = Dataset(DATA_DIR, train=True)
mnist_dataset_test = Dataset(DATA_DIR, train=False)
print(f"Train dataset size: {len(mnist_dataset_train)}")
print(f"Test dataset size: {len(mnist_dataset_test)}")
Initialize PyGrain DataLoaders
train_sampler = pygrain.SequentialSampler(
num_records=len(mnist_dataset_train),
shard_options=pygrain.NoSharding()
)
train_loader_batched = pygrain.DataLoader(
data_source=mnist_dataset_train,
sampler=train_sampler,
operations=[pygrain.Batch(batch_size=BATCH_SIZE, drop_remainder=True)],
)
test_sampler = pygrain.SequentialSampler(
num_records=len(mnist_dataset_test),
shard_options=pygrain.NoSharding()
)
test_loader = pygrain.DataLoader(
data_source=mnist_dataset_test,
sampler=test_sampler,
operations=[pygrain.Batch(batch_size=BATCH_SIZE, drop_remainder=True)],
)
Define MLP Model#
The data is ready! Next let’s define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax NNX to implement a simple MLP.
class MLP(nnx.Module):
"""A simple multilayer perceptron model for image classification."""
def __init__(self, num_inputs: int, num_classes: int, hidden_sizes:
Sequence[int], *, rngs: nnx.Rngs):
self.hidden_sizes = hidden_sizes
self.layer1 = nnx.Linear(num_inputs, self.hidden_sizes[0], rngs=rngs)
self.layer2 = nnx.Linear(self.hidden_sizes[0], self.hidden_sizes[1],
rngs=rngs)
self.layer_out = nnx.Linear(self.hidden_sizes[1], num_classes, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.layer1(x))
x = nnx.relu(self.layer2(x))
x = self.layer_out(x)
return x
def compute_loss_and_accuracy(model, batch):
inputs, labels = batch
logits = model(inputs)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=labels
).mean()
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
return loss, accuracy
@nnx.jit
def train_step(model, optimizer, batch):
"""Performs a one step update."""
grad_fn = nnx.value_and_grad(compute_loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(model, batch)
# In-place update of the model parameters
optimizer.update(grads)
return loss, {"accuracy": accuracy}
@nnx.jit
def eval_step(model, batch):
loss, accuracy = compute_loss_and_accuracy(model, batch)
return loss, {"accuracy": accuracy}
Next we need to initialize network parameters and solver state. We also define a convenience function dataset_stats that we’ll call once per epoch to collect the loss and accuracy of our solver over the test set.
# Initialize RNGs
rngs = nnx.Rngs(0)
# Create the Model
model = MLP(num_inputs=IMG_SIZE, num_classes=N_TARGETS,
hidden_sizes=[1000, 1000], rngs=rngs)
# Create the Optimizer
solver = optax.adam(LEARNING_RATE)
optimizer = nnx.Optimizer(model, solver)
def dataset_stats(model, data_loader):
"""Computes loss and accuracy over the dataset."""
all_accuracy = []
all_loss = []
for batch in data_loader:
loss, aux = eval_step(model, batch)
all_loss.append(loss)
all_accuracy.append(aux["accuracy"])
return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}
Training Loop#
Finally, we do the actual training. The next cell train the model for N_EPOCHS. Within each epoch we iterate over the batched loader train_loader_batched, and once per epoch we also compute the test set accuracy and loss.
train_accuracy = []
train_losses = []
test_accuracy = []
test_losses = []
# Computes test set accuracy at initialization.
test_stats = dataset_stats(model, test_loader)
test_accuracy.append(test_stats["accuracy"])
test_losses.append(test_stats["loss"])
for epoch in range(N_EPOCHS):
train_accuracy_epoch = []
train_losses_epoch = []
# Iterate over the training dataset
for step, train_batch in enumerate(train_loader_batched):
train_loss, train_aux = train_step(model, optimizer, train_batch)
train_accuracy_epoch.append(train_aux["accuracy"])
train_losses_epoch.append(train_loss)
if step % 20 == 0:
print(
f"Step {step}, train loss: {train_loss:.4f}, train accuracy:"
f" {train_aux['accuracy']:.2f}")
# Record training stats
train_accuracy.append(np.mean(train_accuracy_epoch))
train_losses.append(np.mean(train_losses_epoch))
# Evaluate on test set
test_stats = dataset_stats(model, test_loader)
test_accuracy.append(test_stats["accuracy"])
test_losses.append(test_stats["loss"])
f"Improved test accuracy from {test_accuracy[0]:.2f} to {test_accuracy[-1]:.2f}"