Gradient Accumulation#

Gradient accumulation is a technique in which the gradients from several consecutive optimization steps are combined, so that a single update can be applied at regular intervals.

One example where this is useful is to simulate training with a larger batch size than would fit into the available device memory. Another example is in the context of multi-task learning, where batches for different tasks may be visited in a round-robin fashion. Gradient accumulation makes it possible to simulate training on one large batch containing all of the tasks together.

In this example, we show how to implement gradient accumulation using optax.MultiSteps. We start by bringing in some imports and defining some type annotations.

import functools
from typing import Callable, Iterable, Tuple, TypedDict

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import chex


class MiniBatch(TypedDict):
  image: jnp.ndarray
  label: jnp.ndarray


UpdateFn = Callable[[hk.Params, optax.OptState, MiniBatch],
                    Tuple[hk.Params, optax.OptState]]

The following implements a network and loss function that could be used in an image classification problem.

@hk.transform
def net(image: jnp.ndarray) -> jnp.ndarray:
  """A Haiku parameterized function, based on an MLP."""
  features = image.reshape((image.shape[0], -1))
  return hk.nets.MLP([32, 32, 10])(features)


def loss_fn(params: hk.Params, batch: MiniBatch) -> jnp.ndarray:
  """Computes softmax cross entropy for the net outputs batch."""
  logits = net.apply(params, jax.random.PRNGKey(0), batch['image'])
  return optax.softmax_cross_entropy_with_integer_labels(
      logits, batch['label']).mean()

We implement a training loop to perform gradient descent as follows.

def build_update_fn(optimizer: optax.GradientTransformation) -> UpdateFn:
  """Builds a function for executing a single step in the optimization."""

  @jax.jit
  def update(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  return update


def fit(
    optimizer: optax.GradientTransformation,
    params: hk.Params,
    batches: Iterable[MiniBatch],
) -> hk.Params:
  """Executes a train loop over the train batches using the given optimizer."""

  update_fn = build_update_fn(optimizer)
  opt_state = optimizer.init(params)

  for batch in batches:
    params, opt_state = update_fn(params, opt_state, batch)

  return params

The following generates some random image-like data with which to test our network. The shapes used here correspond to those that might appear in an MNIST classifier.

We also initialize some parameters and a base optimizer that are shared across the following examples.

EXAMPLES = jax.random.uniform(jax.random.PRNGKey(0), (9, 28, 28, 1))
LABELS = jax.random.randint(jax.random.PRNGKey(0), (9,), minval=0, maxval=10)

optimizer = optax.sgd(1e-4)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)

Splitting updates for one batch over multiple steps#

The following two snippets compute numerically identical results. The first performs a single optimization step on one large batch; the second uses gradient accumulation over three smaller batches to mimic it.

We start with the snippet that runs the training loop on a single batch containing all of the examples.

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

In this second snippet, our training loop executes three training steps that together also cover all of the examples. In this case, the optimizer is wrapped with optax.MultiSteps, with every_k_schedule=3. This means that instead of applying gradient updates directly, the raw gradients are accumulated until the third step, at which point the wrapped optimizer is applied to the average of the raw gradients seen so far. For the "interim" steps, the updates returned by the optimizer are all zeros, resulting in no change to the parameters during these steps; the sketch after the snippet below demonstrates this.

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)
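
To make the interim behavior concrete, here is a minimal sketch (an illustration added for this discussion, not part of the training loop above) that calls the optax.MultiSteps wrapper directly on the gradients of the first mini-batch. Since fewer than every_k_schedule mini-steps have been accumulated, the returned updates are all zeros.

multi_step_optimizer = optax.MultiSteps(optimizer, every_k_schedule=3)
multi_step_state = multi_step_optimizer.init(params)

# Gradients for the first of the three mini-batches.
grads = jax.grad(loss_fn)(
    params, MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]))

# Only one of the three mini-steps has been taken, so the returned updates
# are all zeros and applying them would leave the parameters unchanged.
updates, multi_step_state = multi_step_optimizer.update(grads, multi_step_state)
chex.assert_trees_all_close(
    updates, jax.tree_util.tree_map(jnp.zeros_like, grads))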

We can now verify that both training loops compute identical results as follows.

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)

Interaction of optax.MultiSteps with schedules#

The snippet below is identical to the one above, except that we additionally introduce a learning rate schedule. As before, the second call to fit uses gradient accumulation, and we again find that both training loops compute identical outputs (up to numerical errors).

This happens because the learning rate schedule in optax.MultiSteps is only advanced once per outer step. In particular, the state of the inner optimizer is only updated after every_k_schedule mini-steps have been taken; the sketch at the end of this section illustrates this.

learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=1.0,
    boundaries_and_scales={
        0: 1e-4,
        1: 1e-1,
    },
)

optimizer = optax.sgd(learning_rate_schedule)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)
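
To see this concretely, the following is a minimal sketch (an illustration, not part of the example above) that feeds the three mini-batches through the optax.MultiSteps wrapper by hand, without applying any updates, and then inspects the mini_step and gradient_step counters in the resulting state. Leaving the parameters fixed is faithful here, since the interim updates are zeros anyway.

multi_step_optimizer = optax.MultiSteps(optimizer, every_k_schedule=3)
multi_step_state = multi_step_optimizer.init(params)

for start in (0, 3, 6):
  batch = MiniBatch(image=EXAMPLES[start:start + 3],
                    label=LABELS[start:start + 3])
  grads = jax.grad(loss_fn)(params, batch)
  _, multi_step_state = multi_step_optimizer.update(grads, multi_step_state)

# After three mini-steps with every_k_schedule=3, exactly one outer step has
# been taken: the mini-step counter has wrapped back around to zero, and the
# gradient-step counter, which is what drives the inner learning rate
# schedule, has advanced to one.
print(multi_step_state.mini_step, multi_step_state.gradient_step)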