optax.microbatching.micro_grad


optax.microbatching.micro_grad(fun: Callable[[...], Any], has_aux: bool = False, argnums: int | Sequence[int] = 0, *, batch_argnums: int | Sequence[int] = 1, keep_batch_dim: bool = True, microbatch_size: int | None = None, accumulator: Accumulator | AccumulationType | Any = AccumulationType.SUM, transform_fn: Callable[[optax.ArrayTree], optax.ArrayTree] = <function <lambda>>, metrics_fn: Callable[[optax.ArrayTree], optax.ArrayTree] = <function <lambda>>, num_real_microbatches: int | jax.Array | None = None) -> Callable[[...], tuple[Any, Aux]]

Create a function to compute, transform, and sum per-example gradients.

This function is similar to jax.value_and_grad, but works at the level of size-1 batches, i.e. individual examples. It is defined in terms of two general transformations, transform_fn and metrics_fn, which can be used, for example, to:

  • limit the effect of outlier batch elements by clipping per-example grads.

  • compute moments of the gradients on a per-example basis.

  • compute scalar or low-dimensional gradient metrics on a per-example basis.
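
To make the first use case concrete, the following is a minimal sketch of per-example gradient clipping written with plain jax.vmap and jax.grad. It does not call micro_grad and is not necessarily how optax implements it; the helper names loss, clip_by_norm and clipped_grad_sum are hypothetical and only illustrate the kind of computation a clipping transform_fn would express.

```python
import jax
import jax.numpy as jnp


def loss(params, features, targets):
  # Mean squared error over whatever batch it is given.
  preds = features @ params
  return 0.5 * jnp.mean((preds - targets) ** 2)


def clip_by_norm(grad, max_norm):
  # Rescale grad so that its L2 norm is at most max_norm.
  norm = jnp.linalg.norm(grad)
  scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
  return grad * scale


def clipped_grad_sum(params, features, targets, max_norm):
  def one_example(x, y):
    # Re-add a size-1 batch axis so that loss sees a batch of one example.
    return jax.grad(loss)(params, x[None], y[None])

  # One gradient per example, shape (batch, dim).
  per_example = jax.vmap(one_example)(features, targets)
  # Clip each example's gradient independently, then sum over the batch.
  clipped = jax.vmap(lambda g: clip_by_norm(g, max_norm))(per_example)
  return clipped.sum(axis=0)


params = jnp.zeros(1)
features = jnp.ones((4, 1))
targets = jnp.array([0.0, 2.0, 4.0, 6.0])
result = clipped_grad_sum(params, features, targets, max_norm=3.0)
```

With these inputs the per-example gradients are [0, -2, -4, -6]; clipping to norm 3 turns them into [0, -2, -3, -3], so result is [-8.]. The two largest examples contribute no more than the clipping threshold allows.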

Other notable differences between this function and jax.value_and_grad:

  • At least one argument to fun must have a batch axis, and that argument should be passed to batch_argnums. The default value of 1 assumes that fun has the signature fun(params, batch, …).

  • The return signature is different. The gradient is always returned as the first output, while all auxiliary outputs are returned as a namedtuple in the second output (including values, function aux, and metrics).

  • This function may be able to work for far larger batch sizes than native jax.value_and_grad due to the built-in microbatching.
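
The memory argument behind the last point can be sketched in plain JAX. The code below is an illustration under stated assumptions, not optax's implementation: it assumes the batch size is divisible by microbatch_size, and the helper names per_example_grads and microbatched_grad_sum are hypothetical. It processes the batch in chunks with jax.lax.scan so that only one chunk of per-example gradients is materialized at a time.

```python
import jax
import jax.numpy as jnp


def loss(params, features, targets):
  preds = features @ params
  return 0.5 * jnp.mean((preds - targets) ** 2)


def per_example_grads(params, features, targets):
  # One gradient per row, each computed on a size-1 batch.
  one = lambda x, y: jax.grad(loss)(params, x[None], y[None])
  return jax.vmap(one)(features, targets)


def microbatched_grad_sum(params, features, targets, microbatch_size):
  # Assumption: the batch size is divisible by microbatch_size.
  num = features.shape[0] // microbatch_size
  xs = features.reshape(num, microbatch_size, *features.shape[1:])
  ys = targets.reshape(num, microbatch_size, *targets.shape[1:])

  def step(total, chunk):
    x, y = chunk
    # Only microbatch_size per-example gradients exist at this point.
    return total + per_example_grads(params, x, y).sum(axis=0), None

  total, _ = jax.lax.scan(step, jnp.zeros_like(params), (xs, ys))
  return total


params = jnp.array([0.5, -1.0])
features = jnp.arange(16.0).reshape(8, 2)
targets = jnp.arange(8.0)

full = per_example_grads(params, features, targets).sum(axis=0)
chunked = microbatched_grad_sum(params, features, targets, microbatch_size=2)
```

Here full and chunked agree, but the chunked version never holds more than two per-example gradients at once, which is what allows the batch size to grow without a matching growth in peak memory.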

Example Usage (see https://arxiv.org/abs/2510.00236):
>>> import jax.numpy as jnp
>>> import optax
>>> def mean_squared_loss(params, features, targets):
...   preds = features @ params
...   diff = preds - targets
...   return 0.5 * jnp.mean(diff**2)
>>> params = jnp.zeros(1)
>>> features = jnp.ones((4, 1))
>>> targets = jnp.array([0, 2, 4, 6])
>>> (grads, squared_grads), aux = optax.microbatching.micro_grad(
...     mean_squared_loss,
...     argnums=0,
...     batch_argnums=(1,2),
...     accumulator=optax.microbatching.AccumulationType.MEAN,
...     transform_fn=lambda x: (x, x**2),
...     metrics_fn=jnp.linalg.norm
... )(params, features, targets)
>>> grads, squared_grads  # per-example grads are [0, -2, -4, -6]
(Array([-3.], dtype=float32), Array([14.], dtype=float32))
>>> aux.values
Array([ 0.,  2.,  8., 18.], dtype=float32)
>>> aux.metrics
Array([0., 2., 4., 6.], dtype=float32)
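
The numbers in this example can be cross-checked without micro_grad. The sketch below recomputes them with jax.vmap and jax.value_and_grad on size-1 batches; it is an independent check rather than the library's code path, and the helper name one_example is hypothetical.

```python
import jax
import jax.numpy as jnp


def mean_squared_loss(params, features, targets):
  preds = features @ params
  diff = preds - targets
  return 0.5 * jnp.mean(diff**2)


params = jnp.zeros(1)
features = jnp.ones((4, 1))
targets = jnp.array([0.0, 2.0, 4.0, 6.0])


def one_example(x, y):
  # Evaluate the loss and its gradient on a batch holding a single example.
  return jax.value_and_grad(mean_squared_loss)(params, x[None], y[None])


values, per_example = jax.vmap(one_example)(features, targets)
mean_grad = per_example.mean(axis=0)                # first moment
mean_squared_grad = (per_example**2).mean(axis=0)   # second moment
norms = jax.vmap(jnp.linalg.norm)(per_example)      # per-example metric
```

The per-example gradients are [0, -2, -4, -6], so their mean is -3 and the mean of their squares is 14, matching grads and squared_grads above; values and norms match aux.values and aux.metrics.
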
Parameters:
  • fun – The function to compute the gradient of.

  • has_aux – Whether the function returns auxiliary output.

  • argnums – The indices of argument(s) to differentiate with respect to.

  • batch_argnums – The indices of argument(s) with a batch axis.

  • keep_batch_dim – Whether fun expects inputs to have a batch dimension.

  • microbatch_size – The size of the microbatches to use when computing the per-example gradients. See microbatch for more details.

  • accumulator – Specifies how to combine or aggregate the transformed gradients across the batch axis.

  • transform_fn – A function to apply to per-example gradients before they are accumulated.

  • metrics_fn – A function to apply to per-example gradients before transforming. Its output is returned on a per-example basis as part of the auxiliary output, and should therefore be scalar or low-dimensional.

  • num_real_microbatches – Optional number of microbatches that are actually executed. If specified, microbatching will terminate early after this many steps. Can be helpful to handle variable batch sizes without recompilation.
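
One way to picture why early termination avoids recompilation is the sketch below. It is an assumption about the general mechanism, not a description of optax's internals: the input is padded to a fixed number of microbatches so its shape never changes, and the number of executed steps is a traced value passed to jax.lax.fori_loop. The function name sum_first_microbatches is hypothetical, and a plain array stands in for per-microbatch results.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_first_microbatches(microbatches, num_real):
  # microbatches has a fixed shape (num_microbatches, microbatch_size, dim).
  # num_real is a traced integer, so new values reuse the compiled function.
  def body(i, total):
    return total + microbatches[i].sum(axis=0)

  init = jnp.zeros(microbatches.shape[2:], microbatches.dtype)
  # With a traced upper bound the loop stops after num_real steps at runtime.
  return jax.lax.fori_loop(0, num_real, body, init)


data = jnp.arange(12.0).reshape(3, 2, 2)
first_two = sum_first_microbatches(data, 2)   # skips the last microbatch
all_three = sum_first_microbatches(data, 3)   # same compiled code, more steps
```

Both calls share one compilation because only the value of num_real differs, not any array shape. Padding rows in the skipped microbatches are never read.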

Returns:

A function that computes the value and gradient of fun, combining the results over microbatches as specified by accumulator and applying transform_fn and metrics_fn as described above. The auxiliary outputs (values, metrics, and function aux) are all returned on a per-example basis.