Gradient Accumulation
Gradient accumulation is a technique in which the gradients from several consecutive optimization steps are accumulated and only applied at regular intervals.
One situation where this is useful is simulating training with a larger batch size than fits into the available device memory. Another arises in multi-task learning, where batches for different tasks are visited in a round-robin fashion; gradient accumulation makes it possible to simulate training on one large batch containing all of the tasks together.
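As a rough illustration of the idea, the following is a minimal sketch (separate from the optax API used in the rest of this example; the grad_fn and micro_batches arguments are hypothetical) showing that accumulation amounts to averaging per-micro-batch gradients before applying a single update:
import jax
import jax.numpy as jnp


def accumulated_gradients(grad_fn, params, micro_batches):
  """Averages grad_fn(params, batch) over a list of micro-batches."""
  grads = [grad_fn(params, batch) for batch in micro_batches]
  # Average the gradient pytrees leaf-wise across the micro-batches.
  return jax.tree_util.tree_map(
      lambda *g: jnp.mean(jnp.stack(g), axis=0), *grads)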
This example shows how to implement gradient accumulation using optax.MultiSteps. We start with some imports and type annotations.
import functools
from typing import Callable, Iterable, Tuple, TypedDict
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import chex
class MiniBatch(TypedDict):
  image: jnp.ndarray
  label: jnp.ndarray


UpdateFn = Callable[[hk.Params, optax.OptState, MiniBatch],
                    Tuple[hk.Params, optax.OptState]]
The following implements a network and loss function that could be used in an image classification problem.
@hk.transform
def net(image: jnp.ndarray) -> jnp.ndarray:
  """A Haiku parameterized function, based on an MLP."""
  features = image.reshape((image.shape[0], -1))
  return hk.nets.MLP([32, 32, 10])(features)


def loss_fn(params: hk.Params, batch: MiniBatch) -> jnp.ndarray:
  """Computes softmax cross entropy of the net outputs on the given batch."""
  logits = net.apply(params, jax.random.PRNGKey(0), batch['image'])
  return optax.softmax_cross_entropy_with_integer_labels(
      logits, batch['label']).mean()
We implement a training loop to perform gradient descent as follows.
def build_update_fn(optimizer: optax.GradientTransformation) -> UpdateFn:
  """Builds a function for executing a single step in the optimization."""

  @jax.jit
  def update(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  return update
def fit(
    optimizer: optax.GradientTransformation,
    params: hk.Params,
    batches: Iterable[MiniBatch],
) -> hk.Params:
  """Executes a train loop over the train batches using the given optimizer."""
  update_fn = build_update_fn(optimizer)
  opt_state = optimizer.init(params)
  for batch in batches:
    params, opt_state = update_fn(params, opt_state, batch)
  return params
The following generates some random image-like data to test our network with. The shapes used here correspond to those that might appear in an MNIST classifier.
We also initialize the parameters and a base optimizer shared across the following examples.
EXAMPLES = jax.random.uniform(jax.random.PRNGKey(0), (9, 28, 28, 1))
LABELS = jax.random.randint(jax.random.PRNGKey(0), (9,), minval=0, maxval=10)
optimizer = optax.sgd(1e-4)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)
Splitting updates for one batch over multiple steps
The following two snippets compute numerically identical results. The first performs a single optimization step on one large batch; the second uses gradient accumulation over three smaller batches to mimic it.
We start with the snippet that runs the training loop over a single batch containing all of the examples,
new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)
In the second snippet, the training loop executes three training steps that together cover all of the examples. Here the optimizer is wrapped with optax.MultiSteps with every_k_schedule=3. This means that instead of applying the gradient updates directly, the raw gradients are accumulated until the third step, at which point the wrapped optimizer is applied to the average of the raw gradients seen so far. For the "interim" steps, the updates returned by the optimizer are all-zeros, resulting in no change to the parameters during these steps; a short sketch after the verification below checks this directly.
new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)
We can now verify that both training loops compute identical results as follows.
chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)
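As a sanity check on the claim that interim steps produce all-zero updates, here is a minimal sketch (an illustration, not part of the original example) that inspects the updates returned by optax.MultiSteps after a single mini-batch:
# A minimal sketch: on an "interim" step, optax.MultiSteps returns all-zero
# updates, so applying them leaves the parameters unchanged.
multi_step_optimizer = optax.MultiSteps(optax.sgd(1e-4), every_k_schedule=3)
multi_step_state = multi_step_optimizer.init(params)

grads = jax.grad(loss_fn)(
    params, MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]))
updates, multi_step_state = multi_step_optimizer.update(grads, multi_step_state)

# The first of the three mini-steps does not touch the parameters yet.
chex.assert_trees_all_close(
    optax.apply_updates(params, updates),
    params,
)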
Interaction of optax.MultiSteps with schedules
The snippet below is identical to the one above, except that we additionally introduce a learning rate schedule. As before, the second call to fit uses gradient accumulation, and we find that both training loops compute identical outputs (up to numerical error).
This works because the learning rate schedule in optax.MultiSteps only advances once per outer step: the state of the inner optimizer is only updated each time every_k_schedule mini-steps have been taken. A short sketch at the end of this section inspects the optimizer state to illustrate this.
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=1.0,
    boundaries_and_scales={
        0: 1e-4,
        1: 1e-1,
    },
)

optimizer = optax.sgd(learning_rate_schedule)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)
chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)
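To make the previous point concrete, here is a minimal sketch (an illustration, not part of the original example) that steps a MultiSteps optimizer over the three mini-batches and prints its step counters; the field names mini_step and gradient_step are assumed from optax's MultiStepsState and may differ across optax versions.
# A minimal sketch: with every_k_schedule=3, the inner optimizer (and hence
# the learning rate schedule) only advances once after three mini-steps.
# The state fields `mini_step` and `gradient_step` are assumptions based on
# optax's MultiStepsState.
multi_step_optimizer = optax.MultiSteps(
    optax.sgd(learning_rate_schedule), every_k_schedule=3)
multi_step_state = multi_step_optimizer.init(params)

for batch in [
    MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
    MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
    MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
]:
  grads = jax.grad(loss_fn)(params, batch)
  _, multi_step_state = multi_step_optimizer.update(grads, multi_step_state)
  # The gradient_step counter should only increment on the third mini-step.
  print(multi_step_state.mini_step, multi_step_state.gradient_step)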