Microbatching
A general microbatching transformation.
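The sketch below shows what a microbatching transformation does, assuming the per-microbatch results are combined by summation. The helper splits the leading (batch) axis into fixed-size microbatches, runs a function on one microbatch at a time with jax.lax.map, and sums the results. The name microbatched_sum is hypothetical and is not this module's API.

```python
import jax
import jax.numpy as jnp


def microbatched_sum(fn, batch, microbatch_size):
    """Hypothetical helper: apply `fn` to consecutive microbatches and sum the results."""
    batch_size = batch.shape[0]
    if batch_size % microbatch_size != 0:
        raise ValueError("batch size must be divisible by microbatch_size")
    num_microbatches = batch_size // microbatch_size
    # [B, ...] -> [num_microbatches, microbatch_size, ...]
    reshaped = batch.reshape((num_microbatches, microbatch_size) + batch.shape[1:])
    # jax.lax.map loops over the leading axis sequentially, one microbatch per step.
    per_microbatch = jax.lax.map(fn, reshaped)
    # Combine the per-microbatch outputs by summation (an assumption of this sketch).
    return jnp.sum(per_microbatch, axis=0)


x = jnp.arange(12.0).reshape(6, 2)
result = microbatched_sum(lambda mb: jnp.sum(mb ** 2, axis=0), x, microbatch_size=2)
```

Processing microbatches sequentially gives up some parallelism in exchange for lower peak memory, because the intermediates inside fn only ever cover microbatch_size examples at once.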
A generalized version of jax.vmap that supports microbatching.
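The relationship to jax.vmap can be shown with a small sketch. It vectorizes within each microbatch, loops over microbatches, and then restores the original [B, ...] layout so that the output matches jax.vmap(fn)(batch). The name microbatched_vmap is hypothetical. This version handles only a single array argument whose leading axis is divisible by microbatch_size, whereas the transformation summarized above is described as more general.

```python
import jax
import jax.numpy as jnp


def microbatched_vmap(fn, microbatch_size):
    """Hypothetical sketch: like jax.vmap(fn) over one array argument, but
    vectorizes only `microbatch_size` examples at a time."""
    vectorized = jax.vmap(fn)

    def wrapped(batch):
        batch_size = batch.shape[0]
        if batch_size % microbatch_size != 0:
            raise ValueError("batch size must be divisible by microbatch_size")
        num_microbatches = batch_size // microbatch_size
        # [B, ...] -> [num_microbatches, microbatch_size, ...]
        chunks = batch.reshape((num_microbatches, microbatch_size) + batch.shape[1:])
        # Vectorize within a microbatch; loop sequentially across microbatches.
        out = jax.lax.map(vectorized, chunks)
        # [num_microbatches, microbatch_size, ...] -> [B, ...]
        return out.reshape((batch_size,) + out.shape[2:])

    return wrapped


def squared_norm(v):
    return jnp.dot(v, v)


x = jnp.arange(8.0).reshape(4, 2)
ys = microbatched_vmap(squared_norm, microbatch_size=2)(x)
```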
Create a function to compute, transform, and sum per-example gradients.
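The compute, transform, and sum pattern can be sketched as follows. The example assumes the transform is per-example L2 clipping of the kind used in DP-SGD; clipping is chosen only for illustration, since the summary above does not name a specific transform. The names make_transformed_grad_sum, clip_to_norm, and linear_loss are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_transformed_grad_sum(loss_fn, transform):
    """Hypothetical factory: returns fn(params, batch) computing
    sum_i transform(grad of loss_fn(params, batch[i]))."""
    # Per-example gradients: params are shared (in_axes=None), examples are mapped (0).
    per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))

    def grad_sum(params, batch):
        grads = per_example_grads(params, batch)  # every leaf gains a leading batch axis
        transformed = jax.vmap(transform)(grads)  # transform each example's gradient
        return jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), transformed)

    return grad_sum


def clip_to_norm(grad, clip_norm=1.0):
    """Rescale a gradient pytree so its global L2 norm is at most clip_norm."""
    leaves = jax.tree_util.tree_leaves(grad)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grad)


def linear_loss(params, x):
    return jnp.dot(params["w"], x)  # gradient with respect to w is simply x


params = {"w": jnp.zeros(2)}
batch = jnp.array([[3.0, 4.0], [0.0, 0.5]])
summed = make_transformed_grad_sum(linear_loss, clip_to_norm)(params, batch)
```

Applying the transform before summing, rather than to the already-summed gradient, is what bounds each individual example's contribution: here the first gradient [3, 4] is scaled down to [0.6, 0.8], while [0, 0.5] is left unchanged.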
Reshape batch axis of pytree leaves for use with microbatching.
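Reshaping the batch axis can be sketched with jax.tree_util.tree_map. The helper split_into_microbatches is hypothetical. It assumes every leaf shares the same leading batch size and that each microbatch holds consecutive examples; the library may use a different ordering.

```python
import jax
import jax.numpy as jnp


def split_into_microbatches(tree, microbatch_size):
    """Hypothetical helper: reshape every leaf from [B, ...] to
    [B // microbatch_size, microbatch_size, ...], grouping consecutive examples."""

    def reshape_leaf(leaf):
        batch_size = leaf.shape[0]
        if batch_size % microbatch_size != 0:
            raise ValueError(f"batch size {batch_size} is not divisible by {microbatch_size}")
        return leaf.reshape((batch_size // microbatch_size, microbatch_size) + leaf.shape[1:])

    return jax.tree_util.tree_map(reshape_leaf, tree)


batch = {"x": jnp.ones((6, 3)), "y": jnp.arange(6)}
microbatches = split_into_microbatches(batch, microbatch_size=2)
# microbatches["x"].shape == (3, 2, 3); microbatches["y"].shape == (3, 2)
```

For sum-style accumulation any grouping works, provided the inverse reshape that restores [B, ...] uses the same layout.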
The type of accumulation to perform.
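An accumulation type typically decides how per-microbatch outputs are combined into one result. The enum below is a hypothetical illustration only. Its members SUM, MEAN, and CONCAT, and the combine helper, are assumptions and are not this module's actual values.

```python
import enum

import jax.numpy as jnp


class Accumulation(enum.Enum):
    """Hypothetical accumulation kinds (not the library's actual members)."""

    SUM = "sum"
    MEAN = "mean"
    CONCAT = "concat"


def combine(stacked, kind):
    """Combine per-microbatch outputs stacked along a leading microbatch axis."""
    if kind is Accumulation.SUM:
        return jnp.sum(stacked, axis=0)
    if kind is Accumulation.MEAN:
        return jnp.mean(stacked, axis=0)
    if kind is Accumulation.CONCAT:
        # [num_microbatches, microbatch_size, ...] -> [B, ...]
        return stacked.reshape((-1,) + stacked.shape[2:])
    raise ValueError(f"unknown accumulation kind: {kind}")


stacked = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 microbatches of 2 values
```

Averaging per-microbatch means equals the full-batch mean only when every microbatch has the same size, which is one reason microbatching usually requires the batch size to be divisible by the microbatch size.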
A class for accumulating values in a microbatched function.
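An accumulator for a microbatched loop can be sketched as an immutable pytree carried through jax.lax.scan. SumAccumulator, its methods, and per_microbatch_fn are hypothetical names. The sketch assumes summation is the desired accumulation and keeps a step count so that a mean can be derived.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SumAccumulator(NamedTuple):
    """Hypothetical accumulator: running sum of a pytree plus a step count."""

    total: Any
    count: jax.Array

    @classmethod
    def zeros_like(cls, example):
        # Start from zeros with the same structure, shapes, and dtypes as `example`.
        return cls(jax.tree_util.tree_map(jnp.zeros_like, example), jnp.zeros((), jnp.int32))

    def add(self, value):
        # Return a new accumulator; nothing is mutated in place.
        return SumAccumulator(jax.tree_util.tree_map(jnp.add, self.total, value), self.count + 1)

    def mean(self):
        return jax.tree_util.tree_map(lambda t: t / self.count, self.total)


def per_microbatch_fn(microbatch):
    # Hypothetical per-microbatch computation.
    return {"loss": jnp.sum(microbatch ** 2)}


microbatches = jnp.arange(6.0).reshape(3, 2)  # 3 microbatches of 2 examples


def step(acc, microbatch):
    return acc.add(per_microbatch_fn(microbatch)), None


init = SumAccumulator.zeros_like(per_microbatch_fn(microbatches[0]))
final, _ = jax.lax.scan(step, init, microbatches)
```

Because typing.NamedTuple subclasses are treated as pytrees by JAX, the accumulator can serve directly as the scan carry without any extra registration.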