optax.microbatching.reshape_batch_axis


optax.microbatching.reshape_batch_axis(tree: Any, microbatch_size: int, axis: int = 0) -> Any

Reshape the batch axis of pytree leaves for use with microbatching.

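This function splits the batch axis of each leaf into two axes of shape (num_microbatches, microbatch_size), placed at the position of the original batch axis, where num_microbatches = batch_size / microbatch_size (so the batch size must be a multiple of microbatch_size). The reshape uses column-major order, meaning batch element i lands at position (i % num_microbatches, i // num_microbatches). As a result, any sharding along the batch axis should be preserved in the new microbatch_size axis, while the new num_microbatches axis will generally be replicated.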
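To see why column-major order keeps the sharding, here is a small illustration on a single 1-D leaf with a batch size of 8 and microbatch_size = 2. It uses `jnp.reshape(..., order="F")` to perform the column-major split directly, which is only equivalent for a 1-D leaf; it is an illustration of the ordering, not the optax implementation.

```python
import jax.numpy as jnp

batch_size, microbatch_size = 8, 2
num_microbatches = batch_size // microbatch_size  # 4

x = jnp.arange(batch_size)  # one leaf whose batch axis is axis 0

# Column-major (Fortran-order) split of the batch axis into
# (num_microbatches, microbatch_size): element i lands at
# (i % num_microbatches, i // num_microbatches).
mb = jnp.reshape(x, (num_microbatches, microbatch_size), order="F")
print(mb)
# [[0 4]
#  [1 5]
#  [2 6]
#  [3 7]]

# Each column k (an index along the microbatch_size axis) is the contiguous
# slice x[k*n:(k+1)*n], so a contiguous sharding of the original batch axis
# maps onto the microbatch_size axis.
for k in range(microbatch_size):
    shard = x[k * num_microbatches:(k + 1) * num_microbatches]
    assert (mb[:, k] == shard).all()
```

Each row is one microbatch and draws one element from every contiguous slice, so stepping through the num_microbatches axis touches every shard. This is consistent with the remark that this axis will generally be replicated.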

Parameters:
  • tree – A pytree of jax.Arrays, each having a batch axis.

  • microbatch_size – The size of the sub-batches used for each microbatch.

  • axis – The index of the batch axis to reshape (default 0).

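Returns:

A pytree with the same structure as tree, in which each leaf's batch axis has been replaced by (num_microbatches, microbatch_size) axes.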
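The sketch below applies the same idea to a whole pytree and an arbitrary `axis`. It is a minimal, hypothetical re-implementation (`reshape_batch_axis_sketch` is not an optax function) that assumes every leaf's batch size is divisible by `microbatch_size`; raising `ValueError` otherwise is an assumption of this sketch, not documented optax behaviour. Because `order="F"` would reorder every axis of a multi-dimensional leaf, the sketch instead splits the batch axis row-major into (microbatch_size, num_microbatches) and then swaps the two new axes, which yields the column-major split in place.

```python
from typing import Any

import jax
import jax.numpy as jnp


def reshape_batch_axis_sketch(tree: Any, microbatch_size: int, axis: int = 0) -> Any:
    """Hypothetical re-implementation of the documented behaviour (not optax's source)."""

    def reshape_leaf(x: jax.Array) -> jax.Array:
        batch_size = x.shape[axis]
        # Assumption of this sketch: reject batch sizes that cannot be split evenly.
        if batch_size % microbatch_size != 0:
            raise ValueError(
                f"batch size {batch_size} is not divisible by microbatch_size {microbatch_size}"
            )
        num_microbatches = batch_size // microbatch_size
        # Row-major split of the batch axis into (microbatch_size, num_microbatches)...
        new_shape = x.shape[:axis] + (microbatch_size, num_microbatches) + x.shape[axis + 1:]
        split = x.reshape(new_shape)
        # ...then swapping the two new axes gives the column-major split
        # (num_microbatches, microbatch_size) at the original axis position.
        return jnp.swapaxes(split, axis, axis + 1)

    return jax.tree_util.tree_map(reshape_leaf, tree)


batch = {
    "inputs": jnp.arange(8 * 3).reshape(8, 3),  # batch axis 0, feature axis 1
    "labels": jnp.arange(8),
}
out = reshape_batch_axis_sketch(batch, microbatch_size=2)
print({k: v.shape for k, v in out.items()})
# {'inputs': (4, 2, 3), 'labels': (4, 2)}
```

Per the signature above, the corresponding call to the library function is `optax.microbatching.reshape_batch_axis(batch, microbatch_size=2)`; for a leaf whose batch axis is not the first one, pass it via `axis`.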