optax.microbatching.micro_vmap

optax.microbatching.micro_vmap(fun: Callable[[...], Any], in_axes: int | Sequence[int] = 0, out_axes: Any = 0, *, microbatch_size: int | None = None, vmap_fn: Callable[[Callable[[...], Any], int | Sequence[int], int], Callable[[...], Any]] = jax.vmap, accumulator: Accumulator | AccumulationType | Any = AccumulationType.CONCAT, num_real_microbatches: int | jax.Array | None = None) → Callable[[...], Any]

A generalized version of jax.vmap that supports microbatching.

Because this function incorporates microbatching, you can vmap over arrays with much larger batch axis sizes than jax.vmap without running out of memory. This function generalizes vmap by introducing new keyword arguments microbatch_size and accumulator to control microbatching behavior. It specializes vmap by imposing stricter requirements on in_axes and out_axes.
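The core idea can be illustrated with a minimal pure-JAX sketch. It assumes a single array argument batched along axis 0 and the default CONCAT behavior; `microbatched_vmap_concat` is a hypothetical helper written for illustration, not optax's implementation. The batch is reshaped into microbatches, `jax.lax.map` applies `jax.vmap(fun)` to one microbatch after another, and the per-microbatch outputs are merged back into a single batch axis.

```python
import jax
import jax.numpy as jnp


def microbatched_vmap_concat(fun, microbatch_size):
    """Hypothetical sketch: vmap `fun` over axis 0, one microbatch at a time.

    Illustrates the CONCAT behavior only; this is not optax's implementation.
    """

    def mapped(x):
        batch_size = x.shape[0]
        if batch_size % microbatch_size != 0:
            raise ValueError("microbatch_size must evenly divide the batch size")
        num_microbatches = batch_size // microbatch_size
        # (B, ...) -> (num_microbatches, microbatch_size, ...)
        chunks = x.reshape((num_microbatches, microbatch_size) + x.shape[1:])
        # lax.map applies the vmapped fun to one microbatch after another.
        outs = jax.lax.map(jax.vmap(fun), chunks)
        # Merge the two leading axes back into one batch axis (CONCAT).
        return jax.tree_util.tree_map(
            lambda y: y.reshape((batch_size,) + y.shape[2:]), outs
        )

    return mapped


x = jnp.arange(8)
# Same values as jax.vmap(lambda v: v**2)(x).
print(microbatched_vmap_concat(lambda v: v**2, microbatch_size=2)(x))
```

Because `jax.lax.map` runs sequentially, roughly one microbatch's worth of intermediate values needs to be live at a time, which is the memory versus sequential-compute trade-off that `microbatch_size` controls.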

Example Usage:
>>> import optax
>>> import jax.numpy as jnp
>>> optax.microbatching.micro_vmap(lambda x: x**2)(jnp.arange(8))
Array([ 0,  1,  4,  9, 16, 25, 36, 49], dtype=int32)
Parameters:
  • fun – Function to be mapped over additional axes.

  • in_axes – Array axis to map over. See jax.vmap for more details.

  • out_axes – Unsupported by micro_vmap; must be set to 0.

  • microbatch_size – The number of rows of the overall batch processed in each microbatch. Smaller values reduce memory overhead but require more sequential computation. It must evenly divide the size of the batch axis of the batched arguments.

  • vmap_fn – A function with the same signature as jax.vmap. Can be used, e.g., to pass additional keyword arguments to vmap.

  • accumulator – Specifies how the vmapped outputs are combined. The default (CONCAT) returns each output with a batch axis, matching the behavior of jax.vmap. Reductions over the batch axis, including MEAN and SUM, are also supported and can be used when the full output with a batch axis is not needed and would be too large to fit in memory. The accumulator can be any PyTree prefix of the outputs of fun, which allows different reductions to be applied to different sub-trees.

  • num_real_microbatches – Optional number of microbatches that are actually executed. If specified, microbatching terminates early after this many steps, which can help handle variable batch sizes without recompilation.
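The reduction accumulators and `num_real_microbatches` can be sketched in the same style. The hypothetical `microbatched_sum` below assumes a single array argument batched along axis 0 and keeps a running SUM instead of stacking every output, so the accumulator's size does not grow with the batch. As an assumption about how early termination could work (not a description of optax internals), it uses `num_real_microbatches` as the upper bound of a `jax.lax.fori_loop`; when that bound is a traced value inside `jax.jit`, the same compiled function can be reused for different values.

```python
import jax
import jax.numpy as jnp


def microbatched_sum(fun, microbatch_size, num_real_microbatches=None):
    """Hypothetical sketch: sum fun's per-row outputs, one microbatch at a time.

    Not optax's implementation; it only illustrates a SUM accumulator and an
    early stop after `num_real_microbatches` microbatches.
    """

    def mapped(x):
        batch_size = x.shape[0]
        if batch_size % microbatch_size != 0:
            raise ValueError("microbatch_size must evenly divide the batch size")
        num_microbatches = batch_size // microbatch_size
        chunks = x.reshape((num_microbatches, microbatch_size) + x.shape[1:])
        # Process every microbatch unless told to stop early.
        n = num_microbatches if num_real_microbatches is None else num_real_microbatches

        # Zero accumulator shaped like fun's output for a single row.
        row = jax.ShapeDtypeStruct(x.shape[1:], x.dtype)
        init = jax.tree_util.tree_map(
            lambda s: jnp.zeros(s.shape, s.dtype), jax.eval_shape(fun, row)
        )

        def body(i, acc):
            # Apply fun to microbatch i and add its batch-axis sum to the total.
            out = jax.vmap(fun)(chunks[i])
            return jax.tree_util.tree_map(lambda a, o: a + o.sum(axis=0), acc, out)

        # With a traced upper bound, fori_loop lowers to a while_loop.
        return jax.lax.fori_loop(0, n, body, init)

    return mapped


x = jnp.arange(8.0)
square_sum = jax.jit(lambda x, n: microbatched_sum(lambda v: v**2, 2, n)(x))
print(square_sum(x, 4))  # all 4 microbatches: 140.0
print(square_sum(x, 2))  # first 2 microbatches (rows 0-3): 14.0
```

A MEAN accumulator would follow the same pattern, dividing the running sum by the number of rows actually processed.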

Returns:

A new function that accepts the same args and kwargs as fun, each with an additional batch axis as specified by in_axes.