Microbatching
- microbatch – A general microbatching transformation.
- micro_vmap – A generalized version of jax.vmap that supports microbatching.
- micro_grad – Create a function to compute, transform, and sum per-example gradients.
- reshape_batch_axis – Reshape batch axis of pytree leaves for use with microbatching.
- AccumulationType – The type of accumulation to perform.
- Accumulator – A class for accumulating values in a microbatched function.
- optax.microbatching.microbatch(fun: Callable[[...], Any], argnums: int | Sequence[int], microbatch_size: int | None, accumulator: Accumulator | AccumulationType | Any = AccumulationType.SUM, *, argnames: str | Sequence[str] = (), in_axes: int | Sequence[int] = 0, num_real_microbatches: int | Array | None = None) -> Callable[[...], Any]
A general microbatching transformation.
Conceptually, given fun, this function returns a new function that does something like the following (for the case of the SUM accumulator):

    def microbatched_fun(full_batch):
        accumulator = 0
        for microbatch in full_batch:
            accumulator += fun(microbatch)
        return accumulator

where under the hood the for loop is implemented via a lax.fori_loop and hence forced to be sequential.

This function is useful when evaluating fun on the full input batch exceeds available device memory. By splitting the batch into smaller microbatches and processing them sequentially, peak memory usage can be significantly reduced. Because the function is evaluated on smaller batches, this transformation requires knowledge of how the individual microbatch results should be combined back together (SUM, MEAN, or CONCAT). See the accumulator argument for more details.

Note: For standard functions that do not unpack their positional or keyword arguments, like f(a, b), either argument can be selected for microbatching via argnums or argnames, and the microbatched function can be called using the same conventions as the input function (passing args by position or name), independent of whether the args with batch axes are specified in argnums or argnames. For wrapped functions like f(*args, **kwargs), the positional and keyword arguments passed to the microbatched function must match argnums and argnames respectively. See CallingConventionTest for more details.

- Example Usage:
>>> import jax.numpy as jnp
>>> import optax
>>> fun = lambda x: (x+1, jnp.sum(3*x))
>>> data = jnp.array([1, 2, 3, 4])
>>> fun(data)
(Array([2, 3, 4, 5], dtype=int32), Array(30, dtype=int32))
>>> strategy = (
...     optax.microbatching.AccumulationType.CONCAT,
...     optax.microbatching.AccumulationType.SUM
... )
>>> microbatched_fun = optax.microbatching.microbatch(
...     fun, argnums=0, microbatch_size=2, accumulator=strategy
... )
>>> microbatched_fun(data)
(Array([2, 3, 4, 5], dtype=int32), Array(30, dtype=int32))
- Parameters:
fun – An arbitrary function.
argnums – A sequence of argument indices that have a batch axis.
microbatch_size – The number of rows in the overall batch used in each microbatch. Smaller values reduce memory overhead, but require more sequential computation. This must evenly divide the batch axis size of the batch arguments.
accumulator – Specifies how to combine results from each microbatch. Can be a single Accumulator, a pytree matching the structure of fun’s output with Accumulator values at the leaves, or anything in between (i.e., a PyTree prefix of fun’s output).
argnames – A sequence of keyword argument names that have a batch axis.
in_axes – An integer or sequence of integers indicating the batch axis index for each argument in argnums and argnames. A sequence should be aligned with the list argnums + argnames. The default value of 0 assumes that all arguments have their batch axis on the 0th dimension of the array.
num_real_microbatches – Optional number of microbatches that are actually executed. If specified, microbatching will terminate early after this many steps. Can be helpful to handle variable batch sizes without recompilation.
- Returns:
A new function that evaluates fun sequentially num_microbatches times on subsets of data. Consumes the same args and kwargs as fun.
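The conceptual loop described above can be sketched directly with jax.lax.fori_loop. This is a minimal illustration of SUM-style accumulation, not the library's implementation: the helper name sum_microbatched is hypothetical, it handles only a single array argument with the batch on axis 0, and it slices contiguous microbatches, which may differ from how the library partitions the batch.

```python
import jax
import jax.numpy as jnp


def sum_microbatched(fun, microbatch_size):
    """Hypothetical helper: sum fun over contiguous microbatches of axis 0."""

    def wrapped(full_batch):
        num_microbatches = full_batch.shape[0] // microbatch_size
        # Shape and dtype of one microbatch result, used to build a zero accumulator.
        out_struct = jax.eval_shape(fun, full_batch[:microbatch_size])
        init = jnp.zeros(out_struct.shape, out_struct.dtype)

        def body(i, acc):
            # Slice out the i-th microbatch and add its result to the accumulator.
            microbatch = jax.lax.dynamic_slice_in_dim(
                full_batch, i * microbatch_size, microbatch_size, axis=0
            )
            return acc + fun(microbatch)

        # fori_loop forces the microbatches to be processed one after another.
        return jax.lax.fori_loop(0, num_microbatches, body, init)

    return wrapped


fun = lambda x: jnp.sum(3 * x)
data = jnp.array([1, 2, 3, 4])
full_result = fun(data)
microbatched_result = sum_microbatched(fun, microbatch_size=2)(data)
```

Because the function here is a sum over the batch, the microbatched result matches the full-batch result while only one microbatch is evaluated at a time.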
- optax.microbatching.micro_vmap(fun: Callable[[...], Any], in_axes: int | Sequence[int] = 0, out_axes: Any = 0, *, microbatch_size: int | None = None, vmap_fn: Callable[[Callable[[...], Any], int | Sequence[int], int], Callable[[...], Any]] = <function vmap>, accumulator: Accumulator | AccumulationType | Any = AccumulationType.CONCAT, num_real_microbatches: int | Array | None = None) -> Callable[[...], Any]
A generalized version of jax.vmap that supports microbatching.
Because this function incorporates microbatching, you can vmap over arrays with much larger batch axis sizes than jax.vmap without running out of memory. This function generalizes vmap by introducing new keyword arguments microbatch_size and accumulator to control microbatching behavior. It specializes vmap by imposing stricter requirements on in_axes and out_axes.
- Example Usage:
>>> import optax
>>> import jax.numpy as jnp
>>> optax.microbatching.micro_vmap(lambda x: x**2)(jnp.arange(8))
Array([ 0, 1, 4, 9, 16, 25, 36, 49], dtype=int32)
- Parameters:
fun – Function to be mapped over additional axes.
in_axes – Array axis to map over. See jax.vmap for more details.
out_axes – Not configurable in micro_vmap; must be set to 0.
microbatch_size – The number of rows in the overall batch used in each microbatch. Smaller values reduce memory overhead, but require more sequential computation. This must evenly divide the batch axis size of the batch arguments.
vmap_fn – A function with the same signature as jax.vmap. Can be used, e.g., to pass kwargs to vmap.
accumulator – Specifies what to do with the vmapped outputs. The default value (CONCAT) returns each output with a batch axis, matching the behavior of jax.vmap. Reductions over the batch axis are also possible, including MEAN and SUM, and can be used when the full output with a batch axis is not needed and is too large to fit in memory. This accumulator can be any PyTree prefix of the outputs of fun, to apply different reductions to different sub-trees.
num_real_microbatches – Optional number of microbatches that are actually executed. If specified, microbatching will terminate early after this many steps. Can be helpful to handle variable batch sizes without recompilation.
- Returns:
A new function with the same args and kwargs having an additional batch axis (according to in_axes).
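One way to picture a microbatched vmap is to vectorize within each microbatch with jax.vmap and iterate over microbatches sequentially with jax.lax.map. The sketch below is an assumption-laden illustration, not the library's code: chunked_vmap is a hypothetical name, it supports one array argument mapped over axis 0, and it splits the batch into contiguous chunks.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fun, microbatch_size):
    """Hypothetical helper: vmap fun over axis 0, one microbatch at a time."""

    def wrapped(xs):
        num_microbatches = xs.shape[0] // microbatch_size
        # Split the batch axis into (num_microbatches, microbatch_size).
        chunks = xs.reshape(num_microbatches, microbatch_size, *xs.shape[1:])
        # lax.map runs the chunks sequentially; vmap vectorizes within a chunk.
        out = jax.lax.map(jax.vmap(fun), chunks)
        # Merge the two leading axes back into one batch axis (CONCAT-like).
        return out.reshape(xs.shape[0], *out.shape[2:])

    return wrapped


xs = jnp.arange(8)
result = chunked_vmap(lambda x: x**2, microbatch_size=4)(xs)
reference = jax.vmap(lambda x: x**2)(xs)
```

Only microbatch_size rows are vectorized at once, which is what bounds peak memory; the concatenated result agrees with a plain jax.vmap over the whole batch.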
- optax.microbatching.micro_grad(fun: Callable[[...], Any], has_aux: bool = False, argnums: int | Sequence[int] = 0, *, batch_argnums: int | Sequence[int] = 1, keep_batch_dim: bool = True, microbatch_size: int | None = None, accumulator: Accumulator | AccumulationType | Any = AccumulationType.SUM, transform_fn: Callable[[optax.ArrayTree], optax.ArrayTree] = <function <lambda>>, metrics_fn: Callable[[optax.ArrayTree], optax.ArrayTree] = <function <lambda>>, num_real_microbatches: int | Array | None = None) -> Callable[[...], tuple[Any, Aux]]
Create a function to compute, transform, and sum per-example gradients.
This function is similar to jax.value_and_grad, but works at the level of size-1 batches. It is defined in terms of general transformations transform_fn and metrics_fn, which can be useful to, e.g.:
* limit the effect of outlier batch elements by clipping per-example grads.
* compute moments of the gradients on a per-example basis.
* compute scalar or low-dimensional gradient metrics on a per-example basis.

Other notable differences between this function and jax.value_and_grad:
* At least one argument to fun must have a batch axis, and that argument should be passed to batch_argnums. The default value of 1 assumes that fun has the signature fun(params, batch, …).
* The return signature is different. The gradient is always returned as the first output, while all auxiliary outputs are returned as a namedtuple in the second output (including values, function aux, and metrics).
* This function may be able to work for far larger batch sizes than native jax.value_and_grad due to the built-in microbatching.
- Example Usage (see https://arxiv.org/abs/2510.00236):
>>> import jax.numpy as jnp
>>> import optax
>>> def mean_squared_loss(params, features, targets):
...     preds = features @ params
...     diff = preds - targets
...     return 0.5 * jnp.mean(diff**2)
>>> params = jnp.zeros(1)
>>> features = jnp.ones((4, 1))
>>> targets = jnp.array([0, 2, 4, 6])
>>> (grads, squared_grads), aux = optax.microbatching.micro_grad(
...     mean_squared_loss,
...     argnums=0,
...     batch_argnums=(1,2),
...     accumulator=optax.microbatching.AccumulationType.MEAN,
...     transform_fn=lambda x: (x, x**2),
...     metrics_fn=jnp.linalg.norm
... )(params, features, targets)
>>> grads, squared_grads  # per-example grads are [0, -2, -4, -6]
(Array([-3.], dtype=float32), Array([14.], dtype=float32))
>>> aux.values
Array([ 0., 2., 8., 18.], dtype=float32)
>>> aux.metrics
Array([0., 2., 4., 6.], dtype=float32)
- Parameters:
fun – The function to compute the gradient of.
has_aux – Whether the function returns auxiliary output.
argnums – The indices of argument(s) to differentiate with respect to.
batch_argnums – The indices of argument(s) with a batch axis.
keep_batch_dim – Whether fun expects inputs to have a batch dimension.
microbatch_size – The size of the microbatches to use when computing the per-example gradients. See microbatch for more details.
accumulator – Specifies how to combine or aggregate the transformed gradients across the batch axis.
transform_fn – A function to apply to per-example gradients before they are accumulated.
metrics_fn – A function to apply to per-example gradients before transforming. Will be returned on a per-example basis as part of the auxiliary output, and therefore should be scalar or low-dimensional.
num_real_microbatches – Optional number of microbatches that are actually executed. If specified, microbatching will terminate early after this many steps. Can be helpful to handle variable batch sizes without recompilation.
- Returns:
A function that computes the value and gradient of fun, accumulating the results over microbatches as specified by accumulator and applying the transform_fn and metrics_fn as described above. The auxiliary output (including values, metrics, function aux) will all be returned on a per-example basis.
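The per-example pipeline (compute, transform, accumulate, with per-example metrics on the side) can be sketched with plain jax.vmap and jax.value_and_grad. This does not call micro_grad and omits microbatching; the names clip and per_example and the threshold max_norm=3.0 are hypothetical, chosen only to illustrate gradient clipping as a transform and the gradient norm as a metric.

```python
import jax
import jax.numpy as jnp


def loss(params, features, targets):
    preds = features @ params
    return 0.5 * jnp.mean((preds - targets) ** 2)


def clip(grad, max_norm=3.0):
    """Hypothetical transform: rescale a gradient to norm at most max_norm."""
    norm = jnp.linalg.norm(grad)
    return grad * jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))


def per_example(params, x, y):
    # Re-add a size-1 batch axis so loss sees a batch of one example.
    value, grad = jax.value_and_grad(loss)(params, x[None], y[None])
    # Return the loss value, the transformed gradient, and a scalar metric.
    return value, clip(grad), jnp.linalg.norm(grad)


params = jnp.zeros(1)
features = jnp.ones((4, 1))
targets = jnp.array([0.0, 2.0, 4.0, 6.0])

# Map over the batch axis of features and targets; params are shared.
values, clipped, norms = jax.vmap(per_example, in_axes=(None, 0, 0))(
    params, features, targets
)
# SUM-like accumulation of the transformed per-example gradients.
summed_grad = clipped.sum(axis=0)
```

Here the raw per-example gradients are 0, -2, -4 and -6, so the last two are rescaled to norm 3 before summation, while values and norms stay on a per-example basis.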
- optax.microbatching.reshape_batch_axis(tree: Any, microbatch_size: int, axis: int = 0) -> Any
Reshape batch axis of pytree leaves for use with microbatching.
This function reshapes the batch axis of each leaf into a shape (num_microbatches, microbatch_size) appearing at the same axis as the original batch axis. The reshape is done using a column-major order, so any sharding along the batch axis should be preserved in the new microbatch_size axis, while the new num_microbatches axis will generally be replicated.
- Parameters:
tree – A pytree of jax.Arrays, each having a batch axis.
microbatch_size – The size of sub-batches used for each microbatch.
axis – The axis to reshape.
- Returns:
A pytree of reshaped jax.Arrays.
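As a rough illustration of a column-major split, the sketch below reshapes a single array with its batch on axis 0. The helper split_column_major is hypothetical, and the exact element layout is one reading of the description above, not a confirmed account of what reshape_batch_axis produces.

```python
import jax.numpy as jnp


def split_column_major(x, microbatch_size):
    """Hypothetical helper: (batch, ...) -> (num_microbatches, microbatch_size, ...)."""
    num_microbatches = x.shape[0] // microbatch_size
    # Row-major reshape to (microbatch_size, num_microbatches, ...), then swap
    # the two leading axes; element [i, j] is x[i + j * num_microbatches].
    y = x.reshape(microbatch_size, num_microbatches, *x.shape[1:])
    return jnp.swapaxes(y, 0, 1)


x = jnp.arange(8)
microbatches = split_column_major(x, microbatch_size=2)
```

Under this layout each microbatch takes strided rather than contiguous rows: with eight rows and a microbatch size of two, the first microbatch holds rows 0 and 4.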
- class optax.microbatching.AccumulationType(value)
The type of accumulation to perform.
- MEAN = 1
Average the microbatch outputs.
- SUM = 2
Sum the microbatch outputs.
- RUNNING_MEAN = 3
Average the microbatch outputs over num_real_microbatches.
- CONCAT = 4
Concatenate the microbatch outputs along axis 0.
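The reductions named above can be illustrated on a small list of made-up microbatch outputs. This sketch uses plain jax.numpy operations rather than the enum itself; the RUNNING_MEAN line reflects one reading of its description (averaging only over the microbatches that actually ran), and the ordering produced by the library's CONCAT may differ from a simple concatenation.

```python
import jax.numpy as jnp

# Outputs of a hypothetical function on three microbatches of two rows each.
outputs = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), jnp.array([5.0, 6.0])]
stacked = jnp.stack(outputs)  # shape (num_microbatches, 2)

# SUM-like: elementwise sum over microbatches.
summed = stacked.sum(axis=0)
# MEAN-like: elementwise average over microbatches.
averaged = stacked.mean(axis=0)
# CONCAT-like: join the microbatch outputs along axis 0.
concatenated = jnp.concatenate(outputs, axis=0)

# RUNNING_MEAN-like (assumed semantics): average only over the microbatches
# that were actually executed.
num_real_microbatches = 2
running_mean = stacked[:num_real_microbatches].mean(axis=0)
```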
- class optax.microbatching.Accumulator(init: PyTreeFn, update: UpdateFn, finalize: PyTreeFn, aggregate: PyTreeFn)
A class for accumulating values in a microbatched function.
Given a list of microbatch function evaluations [x_0, …, x_{n-1}], this object represents the program:

    carry = init(jax.typeof(x_0))
    for i in range(n):
        carry = update(carry, x_i, i)
    return finalize(carry)
- init
A function f(shape_dtype_struct) that initializes the microbatch state from the shape/dtype of a single microbatch evaluation.
- Type:
PyTreeFn
- update
A function f(carry, value, index) that updates the microbatch state with the function evaluation of the current microbatch.
- Type:
UpdateFn
- finalize
A function f(carry) that returns the final result from the final state.
- Type:
PyTreeFn
- aggregate
A function f(per_microbatch_value) that aggregates per-microbatch values into a single value. Used by micro_vmap.
- Type:
PyTreeFn
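The init/update/finalize program can be traced by hand with ordinary Python functions. The sketch below implements an incremental mean as an example; the three functions are hypothetical stand-ins following the documented call shapes, they are not passed to the Accumulator class here, and jax.ShapeDtypeStruct is used as an assumed substitute for the result of jax.typeof.

```python
import jax
import jax.numpy as jnp


# Hypothetical init/update/finalize functions for a running average.
def init(struct):
    # Start from zeros with the shape and dtype of one microbatch result.
    return jnp.zeros(struct.shape, struct.dtype)


def update(carry, value, index):
    # Incremental mean: after this step, carry averages values 0..index.
    return carry + (value - carry) / (index + 1)


def finalize(carry):
    # The carry already holds the mean, so return it unchanged.
    return carry


xs = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), jnp.array([5.0, 6.0])]

# Run the program: initialize, fold in each evaluation, then finalize.
carry = init(jax.ShapeDtypeStruct(xs[0].shape, xs[0].dtype))
for i, x in enumerate(xs):
    carry = update(carry, x, i)
result = finalize(carry)
```

Passing the index to update is what lets a mean be maintained without first storing every microbatch result.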