Gradient Accumulation

Gradient accumulation is a technique in which the gradients from several consecutive optimization steps are combined, so that a single update can be applied at regular, repeating intervals.

One example where this is useful is to simulate training with a larger batch size than would fit into the available device memory. Another example is in the context of multi-task learning, where batches for different tasks may be visited in a round-robin fashion. Gradient accumulation makes it possible to simulate training on one large batch containing all of the tasks together.
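The sketch below is a quick sanity check of the property that makes accumulation work. When the loss is a mean over examples and a batch is split into equally sized micro-batches, averaging the micro-batch gradients reproduces the full-batch gradient. The toy linear model, `mean_loss`, and the random data are hypothetical and not part of Optax.

```python
import jax
import jax.numpy as jnp


def mean_loss(w, x, y):
  """Hypothetical mean-squared-error loss for a linear model."""
  return jnp.mean((x @ w - y) ** 2)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (12, 4))
y = jax.random.normal(key_y, (12,))
w = jnp.zeros((4,))

# Gradient of the mean loss over all 12 examples at once.
full_grad = jax.grad(mean_loss)(w, x, y)

# Split into k = 3 equally sized micro-batches and average their gradients.
k = 3
micro_grads = [
    jax.grad(mean_loss)(w, xb, yb)
    for xb, yb in zip(jnp.split(x, k), jnp.split(y, k))
]
accumulated_grad = sum(micro_grads) / k

print(bool(jnp.allclose(full_grad, accumulated_grad, atol=1e-6)))
```

If the micro-batches had different sizes, a plain average would weight their examples unequally, so this equivalence relies on the equal split.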

In this example, we show how to implement gradient accumulation using optax.MultiSteps. We start by bringing in some imports and defining some type annotations.

import functools
from typing import Callable, Iterable, Tuple, TypedDict

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import chex


class MiniBatch(TypedDict):
  image: jnp.ndarray
  label: jnp.ndarray


UpdateFn = Callable[[hk.Params, optax.OptState, MiniBatch],
                    Tuple[hk.Params, optax.OptState]]

The following implements a network and loss function that could be used in an image classification problem.

@hk.transform
def net(image: jnp.ndarray) -> jnp.ndarray:
  """A Haiku parameterized function, based on an MLP."""
  features = image.reshape((image.shape[0], -1))
  return hk.nets.MLP([32, 32, 10])(features)


def loss_fn(params: hk.Params, batch: MiniBatch) -> jnp.ndarray:
  """Computes softmax cross entropy for the net outputs batch."""
  logits = net.apply(params, jax.random.PRNGKey(0), batch['image'])
  return optax.softmax_cross_entropy_with_integer_labels(
      logits, batch['label']).mean()

We implement a training loop to perform gradient descent as follows.

def build_update_fn(optimizer: optax.GradientTransformation) -> UpdateFn:
  """Builds a function for executing a single step in the optimization."""

  @jax.jit
  def update(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Passing params lets optimizers that need them (e.g. weight decay) work.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  return update


def fit(
    optimizer: optax.GradientTransformation,
    params: hk.Params,
    batches: Iterable[MiniBatch],
) -> hk.Params:
  """Executes a train loop over the train batches using the given optimizer."""

  update_fn = build_update_fn(optimizer)
  opt_state = optimizer.init(params)

  for batch in batches:
    params, opt_state = update_fn(params, opt_state, batch)

  return params

The following generates some random image-like data to test our network with. The shapes used here correspond to those that might appear in an MNIST classifier.

We also initialize some parameters and a base optimizer to share through the following examples.

EXAMPLES = jax.random.uniform(jax.random.PRNGKey(0), (9, 28, 28, 1))
LABELS = jax.random.randint(jax.random.PRNGKey(0), (9,), minval=0, maxval=10)

optimizer = optax.sgd(1e-4)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)

Splitting updates for one batch over multiple steps

The following two snippets will compute numerically identical results, but with the difference that the second snippet will use gradient accumulation over three batches to mimic the first snippet, which performs a single step with one large batch.

We start with the snippet that runs a training loop over a single batch containing all examples:

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

In this second snippet, our training loop executes three training steps that together cover all of the examples. Here the optimizer is wrapped with optax.MultiSteps, with every_k_schedule=3. This means that instead of applying gradient updates directly, the raw gradients are accumulated until the third step. At that step, the wrapped optimizer is applied to the average of the three accumulated gradients. For the “interim” steps, the updates returned by the optimizer are all zeros, so the parameters do not change during these steps.

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

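We can now verify that both training loops produce the same parameters, up to floating-point error. They agree because the three mini-batches have the same size and the loss is a mean over examples, so the average of the three mini-batch gradients equals the gradient over the full batch.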

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)

Interaction of optax.MultiSteps with schedules

The snippet below is identical to the one above, except that we additionally introduce a learning rate schedule. As before, the second call to fit uses gradient accumulation, and again we find that both training loops compute identical outputs (up to numerical error).

This happens because the learning rate schedule inside optax.MultiSteps only advances once per outer step. In particular, the state of the inner optimizer, which holds the step count that the schedule is evaluated at, is only updated each time every_k_schedule mini-steps have been taken.

learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=1.0,
    boundaries_and_scales={
        0: 1e-4,
        1: 1e-1,
    },
)

optimizer = optax.sgd(learning_rate_schedule)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)