optax.microbatching.microbatch


optax.microbatching.microbatch(fun: Callable[[...], Any], argnums: int | Sequence[int], microbatch_size: int | None, accumulator: Accumulator | AccumulationType | Any = AccumulationType.SUM, *, argnames: str | Sequence[str] = (), in_axes: int | Sequence[int] = 0, num_real_microbatches: int | Array | None = None) -> Callable[[...], Any]

A general microbatching transformation.

Conceptually, given fun, this function returns a new function that does something like the following (for the case of SUM accumulator):

def microbatched_fun(full_batch):
  accumulator = 0
  for microbatch in full_batch:
    accumulator += fun(microbatch)
  return accumulator

where, under the hood, the for loop is implemented via lax.fori_loop and is therefore forced to run sequentially.
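The loop above can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: microbatched_sum is a hypothetical helper, and it assumes a single array argument batched on axis 0, a single array output, and a microbatch size that evenly divides the batch.

```python
import jax
import jax.numpy as jnp
from jax import lax


def microbatched_sum(fun, full_batch, microbatch_size):
  """Hypothetical helper: sums fun over consecutive microbatches of axis 0."""
  num_microbatches = full_batch.shape[0] // microbatch_size
  # Reshape (batch, ...) -> (num_microbatches, microbatch_size, ...).
  stacked = full_batch.reshape(
      (num_microbatches, microbatch_size) + full_batch.shape[1:])
  # Derive the output shape and dtype of fun to build a zero-valued carry.
  out = jax.eval_shape(fun, stacked[0])
  init = jnp.zeros(out.shape, out.dtype)

  def body(i, acc):
    # One sequential step: evaluate fun on microbatch i and accumulate.
    return acc + fun(stacked[i])

  return lax.fori_loop(0, num_microbatches, body, init)


fun = lambda x: jnp.sum(3 * x)
data = jnp.array([1, 2, 3, 4])
result = microbatched_sum(fun, data, microbatch_size=2)
```

Only one microbatch is passed to fun at a time, which is where the memory saving comes from.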

This function is useful when evaluating fun on the full input batch exceeds available device memory. By splitting the batch into smaller microbatches and processing them sequentially, peak memory usage can be significantly reduced. Because the function is evaluated on smaller batches, this transformation requires knowledge of how the individual microbatch results should be combined back together (SUM, MEAN, or CONCAT). See the accumulator argument for more details.
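As a rough illustration of the three combination strategies, the sketch below applies each one by hand with an ordinary Python loop. The function names (per_example, total, average) are hypothetical, and the reading of MEAN as an average of equally sized microbatch results is an assumption based on the name, not a statement about the library's internals.

```python
import jax.numpy as jnp

per_example = lambda x: x + 1        # one output row per input row
total = lambda x: jnp.sum(3 * x)     # scalar summed over the batch
average = lambda x: jnp.mean(3.0 * x)  # scalar averaged over the batch

data = jnp.array([1, 2, 3, 4])
microbatches = jnp.split(data, 2)  # two microbatches of size 2

# CONCAT: per-example outputs are stitched back together along the batch axis.
concat_result = jnp.concatenate([per_example(mb) for mb in microbatches])
# SUM: batch-summed outputs are added up.
sum_result = sum(total(mb) for mb in microbatches)
# MEAN: batch-averaged outputs are averaged (valid for equal-size microbatches).
mean_result = sum(average(mb) for mb in microbatches) / len(microbatches)
```

In each case the combined value matches what the function returns on the full batch, which is the property the chosen accumulator needs to preserve.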

Note: For standard functions that do not unpack their positional or keyword arguments like f(a, b), one can specify to microbatch either argument via argnums or argnames, and the microbatched function can be called using the same conventions as the input function (passing args by position or name), independent of whether the args with batch axes are specified in argnums or argnames. For wrapped functions like f(*args, **kwargs), the positional and keyword arguments passed to the microbatched function must match argnums and argnames respectively. See CallingConventionTest for more details.

Example Usage:
>>> import jax.numpy as jnp
>>> import optax
>>> fun = lambda x: (x+1, jnp.sum(3*x))
>>> data = jnp.array([1, 2, 3, 4])
>>> fun(data)
(Array([2, 3, 4, 5], dtype=int32), Array(30, dtype=int32))
>>> strategy = (
...    optax.microbatching.AccumulationType.CONCAT,
...    optax.microbatching.AccumulationType.SUM
... )
>>> microbatched_fun = optax.microbatching.microbatch(
...    fun, argnums=0, microbatch_size=2, accumulator=strategy
... )
>>> microbatched_fun(data)
(Array([2, 3, 4, 5], dtype=int32), Array(30, dtype=int32))

Note

microbatch is compatible with other JAX transformations like jax.grad and jax.vmap. However, when computing gradients, it is generally more efficient to microbatch the gradient than to differentiate through the microbatched function. That is, prefer:

microbatch(jax.grad(loss_fn), ...)(params, batch)

over:

jax.grad(microbatch(loss_fn, ...))(params, batch)

Both produce equivalent results for linear accumulators (SUM, MEAN), but jax.grad(microbatch(...)) differentiates through the internal jax.lax.fori_loop, which requires JAX to save or rematerialize all intermediate loop carries for the backward pass. In contrast, microbatch(jax.grad(...)) computes per-microbatch gradients and accumulates them directly, avoiding this overhead.
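The preferred pattern amounts to gradient accumulation. The sketch below writes it out in plain JAX to show why the results agree for a summed loss. It is not the library's implementation: loss_fn and accumulated_grad are hypothetical, and params is assumed to be a single array rather than a general pytree.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss_fn(params, batch):
  # Hypothetical loss: summed squared residuals of a linear model.
  x, y = batch
  return jnp.sum((x @ params - y) ** 2)


def accumulated_grad(params, batch, microbatch_size):
  """Hypothetical helper: sums per-microbatch gradients sequentially."""
  x, y = batch
  n = x.shape[0] // microbatch_size
  xs = x.reshape((n, microbatch_size) + x.shape[1:])
  ys = y.reshape((n, microbatch_size) + y.shape[1:])

  def body(i, acc):
    # The gradient is taken inside the loop, so nothing differentiates
    # through the loop itself.
    return acc + jax.grad(loss_fn)(params, (xs[i], ys[i]))

  return lax.fori_loop(0, n, body, jnp.zeros_like(params))


params = jnp.array([0.5, -1.0])
x = jnp.arange(8.0).reshape(4, 2)
y = jnp.array([1.0, 2.0, 3.0, 4.0])

full_grad = jax.grad(loss_fn)(params, (x, y))
accum_grad = accumulated_grad(params, (x, y), microbatch_size=2)
```

Because the loss is a sum over examples, the gradient of the full batch equals the sum of the per-microbatch gradients, so full_grad and accum_grad agree.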

Parameters:
  • fun – An arbitrary function.

  • argnums – An index or a sequence of indices of the positional arguments that have a batch axis.

  • microbatch_size – The number of rows of the overall batch used in each microbatch. Smaller values reduce peak memory usage but require more sequential computation. This must evenly divide the batch axis size of the batch arguments.

  • accumulator – Specifies how to combine results from each microbatch; can be a single Accumulator, a pytree matching the structure of fun’s output with Accumulator values at the leaves, or anything in between (i.e., a PyTree prefix of fun’s output).

  • argnames – A keyword argument name or a sequence of names of the keyword arguments that have a batch axis.

  • in_axes – An integer or a sequence of integers giving the batch axis index of each batched argument. A sequence should be aligned with the list argnums + argnames. The default value of 0 assumes that every batched argument has its batch axis on dimension 0.

  • num_real_microbatches – Optional number of microbatches that are actually executed. If specified, microbatching will terminate early after this many steps. Can be helpful to handle variable batch sizes without recompilation.
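The idea behind num_real_microbatches can be illustrated in plain JAX: pad the batch to a fixed shape and pass the number of real microbatches as a runtime value, so the compiled function is reused across batch sizes. This sketch is an assumption about the general technique, not the library's code; padded_sum and the padding value are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def padded_sum(padded_batch, num_real_microbatches):
  """Hypothetical helper: sums only the first num_real_microbatches rows."""
  # padded_batch has a fixed shape (max_microbatches, microbatch_size).
  def body(i, acc):
    return acc + jnp.sum(padded_batch[i])

  init = jnp.zeros((), padded_batch.dtype)
  # The upper bound is a traced value, so the loop stops early at run time
  # and a different count does not change the input shapes.
  return lax.fori_loop(0, num_real_microbatches, body, init)


# Two real microbatches followed by two rows of padding (-1).
padded = jnp.array([[1, 2], [3, 4], [-1, -1], [-1, -1]])
two_real = padded_sum(padded, 2)
one_real = padded_sum(padded, 1)
```

The padding rows are never visited, so their contents do not affect the result.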

Returns:

A new function that evaluates fun sequentially num_microbatches times on subsets of data. Consumes the same args and kwargs as fun.