Microbatching


microbatch(fun, argnums, microbatch_size[, ...])

A general microbatching transformation.
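The sketch below shows the basic idea behind microbatching in plain JAX. It does not use this module's `microbatch`. The helper `microbatch_sum` is hypothetical. It splits only the leading axis of a single argument into chunks of `microbatch_size`, evaluates `fun` on one chunk at a time, and sums the per-chunk results. It does not model `argnums` or any other option of the documented transformation.

```python
import jax
import jax.numpy as jnp


def microbatch_sum(fun, microbatch_size):
    """Hypothetical helper: evaluate `fun` on consecutive microbatches of the
    leading axis of `x` and sum the per-microbatch results."""

    def wrapped(x):
        batch_size = x.shape[0]
        if batch_size % microbatch_size != 0:
            raise ValueError("batch size must be divisible by microbatch_size")
        num_microbatches = batch_size // microbatch_size
        # (B, ...) -> (num_microbatches, microbatch_size, ...)
        chunks = x.reshape((num_microbatches, microbatch_size) + x.shape[1:])
        # lax.map applies `fun` to one chunk at a time (it lowers to a scan),
        # so intermediates exist for only a single microbatch at a time.
        per_chunk = jax.lax.map(fun, chunks)
        return jnp.sum(per_chunk, axis=0)

    return wrapped


def sum_of_squares(batch):
    # Additive over examples, so summing the microbatch results is exact.
    return jnp.sum(batch ** 2)


x = jnp.arange(8.0)
full = sum_of_squares(x)
chunked = microbatch_sum(sum_of_squares, microbatch_size=2)(x)
```

Summing the per-microbatch results reproduces the full-batch value only when `fun` is additive over examples, as `sum_of_squares` is. Other reductions need a different way of combining the microbatch outputs.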

micro_vmap(fun, in_axes, ...)

A generalized version of jax.vmap that supports microbatching.
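The sketch below illustrates the idea in plain JAX and does not reproduce the signature of `micro_vmap`. The function `chunked_vmap` is hypothetical. It vectorizes `fun` with `jax.vmap` inside each microbatch and runs the microbatches one after another with `jax.lax.map`. It handles only a single array argument batched along axis 0 and does not model `in_axes` or the remaining parameters.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fun, microbatch_size):
    """Hypothetical sketch: behaves like jax.vmap(fun) over axis 0, but
    vectorizes over only `microbatch_size` examples at a time."""
    vfun = jax.vmap(fun)

    def wrapped(x):
        batch_size = x.shape[0]
        if batch_size % microbatch_size != 0:
            raise ValueError("batch size must be divisible by microbatch_size")
        chunks = x.reshape(
            (batch_size // microbatch_size, microbatch_size) + x.shape[1:]
        )
        # Output shape: (num_microbatches, microbatch_size, ...)
        outs = jax.lax.map(vfun, chunks)
        # Merge the two leading axes to recover the per-example layout.
        return outs.reshape((batch_size,) + outs.shape[2:])

    return wrapped


def per_example(v):
    return jnp.dot(v, v)


xs = jnp.arange(12.0).reshape(6, 2)
out = chunked_vmap(per_example, microbatch_size=3)(xs)
```

The result should match `jax.vmap(per_example)(xs)`. Only the number of examples in flight in each vectorized call changes.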

micro_grad(fun, has_aux, ...)

Create a function to compute, transform, and sum per-example gradients.
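Differentially private training commonly transforms per-example gradients before summing them, by clipping each one to a maximum L2 norm. The sketch below assumes that use case. The linear-regression `loss`, the clipping rule `clip`, and `summed_clipped_grads` are all hypothetical. The sketch does not model `has_aux` or any other parameter of `micro_grad`.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Squared error of a linear model on a single example.
    pred = jnp.dot(x, params)
    return 0.5 * (pred - y) ** 2


def clip(g, max_norm):
    # Rescale g so that its L2 norm is at most max_norm.
    norm = jnp.linalg.norm(g)
    return g * jnp.minimum(1.0, max_norm / (norm + 1e-12))


def summed_clipped_grads(params, xs, ys, microbatch_size, max_norm):
    per_example_grad = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))
    n = xs.shape[0]
    if n % microbatch_size != 0:
        raise ValueError("batch size must be divisible by microbatch_size")
    k = n // microbatch_size
    xs_mb = xs.reshape((k, microbatch_size) + xs.shape[1:])
    ys_mb = ys.reshape((k, microbatch_size))

    def body(acc, batch):
        x, y = batch
        g = per_example_grad(params, x, y)  # (microbatch_size, d)
        g = jax.vmap(clip, in_axes=(0, None))(g, max_norm)
        return acc + jnp.sum(g, axis=0), None

    # The carry is a single running sum with the shape of params.
    total, _ = jax.lax.scan(body, jnp.zeros_like(params), (xs_mb, ys_mb))
    return total


params = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(24.0).reshape(8, 3) / 10.0
ys = jnp.linspace(-1.0, 1.0, 8)
grads = summed_clipped_grads(params, xs, ys, microbatch_size=4, max_norm=1.0)
```

Accumulating with `jax.lax.scan` stores only one microbatch of per-example gradients at a time, instead of keeping all of them at once.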

reshape_batch_axis(tree, microbatch_size[, axis])

Reshape the batch axis of pytree leaves for use with microbatching.
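The sketch below shows what such a reshape can look like with `jax.tree_util.tree_map`. The name `split_batch_axis` is hypothetical. It splits `axis` of every leaf into `(B // microbatch_size, microbatch_size)` and groups consecutive examples together. This entry does not say which axis order or example grouping the documented function uses, so treat both choices as assumptions.

```python
import jax
import jax.numpy as jnp


def split_batch_axis(tree, microbatch_size, axis=0):
    """Hypothetical sketch: reshape `axis` of every leaf from size B into two
    axes of sizes (B // microbatch_size, microbatch_size)."""

    def reshape_leaf(leaf):
        ax = axis % leaf.ndim  # allow negative axis values
        batch_size = leaf.shape[ax]
        if batch_size % microbatch_size != 0:
            raise ValueError(
                f"axis {ax} has size {batch_size}, not divisible by {microbatch_size}"
            )
        new_shape = (
            leaf.shape[:ax]
            + (batch_size // microbatch_size, microbatch_size)
            + leaf.shape[ax + 1 :]
        )
        # Row-major reshape: microbatch i holds examples
        # i * microbatch_size ... (i + 1) * microbatch_size - 1.
        return leaf.reshape(new_shape)

    return jax.tree_util.tree_map(reshape_leaf, tree)


batch = {"x": jnp.zeros((8, 3)), "y": jnp.arange(8)}
mb = split_batch_axis(batch, microbatch_size=4)
```

Every leaf in `batch` must share the same batch size along `axis`. Otherwise the leaves would be split into different numbers of microbatches.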

AccumulationType(value)

The type of accumulation to perform.
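This entry does not list the enum's members. The sketch below therefore uses an invented `Accumulation` enum, with hypothetical `SUM`, `MEAN`, and `CONCAT` members, to show the kind of choice such a type can encode: how outputs computed per microbatch are combined into one result. The helper `combine` is also hypothetical.

```python
import enum

import jax.numpy as jnp


class Accumulation(enum.Enum):
    """Hypothetical stand-in for an accumulation-type enum."""

    SUM = "sum"
    MEAN = "mean"
    CONCAT = "concat"


def combine(stacked, kind):
    """Combine per-microbatch outputs stacked along axis 0."""
    if kind is Accumulation.SUM:
        return jnp.sum(stacked, axis=0)
    if kind is Accumulation.MEAN:
        return jnp.mean(stacked, axis=0)
    if kind is Accumulation.CONCAT:
        # Merge (num_microbatches, microbatch_size, ...) into one batch axis.
        return stacked.reshape((-1,) + stacked.shape[2:])
    raise ValueError(f"unknown accumulation: {kind}")


# Three microbatches, each producing two values.
stacked = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
summed = combine(stacked, Accumulation.SUM)
averaged = combine(stacked, Accumulation.MEAN)
concatenated = combine(stacked, Accumulation.CONCAT)
```

Averaging per-microbatch means, as `MEAN` does here, equals the full-batch mean only when every microbatch has the same size.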

Accumulator(init, update, finalize, aggregate)

A class for accumulating values in a microbatched function.
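This entry lists only the four constructor fields. The sketch below assigns them assumed roles, guessed from their names:

- `init` builds a state from the first microbatch output.
- `update` folds in each later output.
- `finalize` turns the state into the result.
- `aggregate` merges two partial states, for example ones computed separately.

`SketchAccumulator`, `mean_accumulator`, and `fold` are hypothetical names, and the example implements a running mean.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchAccumulator(NamedTuple):
    # Only the field names mirror the documented Accumulator; the meanings
    # given here are assumptions made for this sketch.
    init: Callable[[Any], Any]            # first microbatch output -> state
    update: Callable[[Any, Any], Any]     # (state, next output) -> state
    finalize: Callable[[Any], Any]        # state -> final result
    aggregate: Callable[[Any, Any], Any]  # (state, state) -> merged state


# Running mean: the state is (running total, number of outputs seen).
mean_accumulator = SketchAccumulator(
    init=lambda out: (out, jnp.array(1.0)),
    update=lambda state, out: (state[0] + out, state[1] + 1.0),
    finalize=lambda state: state[0] / state[1],
    aggregate=lambda a, b: (a[0] + b[0], a[1] + b[1]),
)


def fold(acc, outputs):
    """Fold per-microbatch outputs stacked along axis 0 into one state."""
    state = acc.init(outputs[0])
    state, _ = jax.lax.scan(
        lambda s, out: (acc.update(s, out), None), state, outputs[1:]
    )
    return state


outs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
mean = mean_accumulator.finalize(fold(mean_accumulator, outs))
merged = mean_accumulator.aggregate(
    fold(mean_accumulator, outs[:2]), fold(mean_accumulator, outs[2:])
)
```

The state stores a running total and a count rather than a running mean. As a result, merging two partial states with `aggregate` gives the same result as folding all outputs in one pass.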