optax.microbatching.Accumulator
- class optax.microbatching.Accumulator(init: PyTreeFn, update: UpdateFn, finalize: PyTreeFn, aggregate: PyTreeFn)
A class for accumulating values in a microbatched function.
Given a list of microbatch function evaluations [x_0, …, x_{n-1}], this object represents the following program:
    carry = init(jax.typeof(x_0))
    for i in range(n):
        carry = update(carry, x_i, i)
    return finalize(carry)
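The program above can be sketched as a plain Python driver. This is a minimal illustration under stated assumptions, not optax's implementation: `run_accumulator` and the `sum_*` callables are hypothetical names, and `jax.ShapeDtypeStruct` is used as a stand-in for `jax.typeof(x_0)`, so `init` only ever sees the shape and dtype of the first evaluation.

```python
import jax
import jax.numpy as jnp


def run_accumulator(init, update, finalize, xs):
    """Hypothetical driver: runs init/update/finalize over xs = [x_0, ..., x_{n-1}]."""
    # Describe x_0 by shape/dtype only (stand-in for jax.typeof(x_0)).
    struct = jax.tree_util.tree_map(
        lambda a: jax.ShapeDtypeStruct(jnp.shape(a), jnp.result_type(a)), xs[0]
    )
    carry = init(struct)
    for i, x_i in enumerate(xs):
        carry = update(carry, x_i, i)
    return finalize(carry)


# Hypothetical summing accumulator: start at zeros, add each evaluation.
sum_init = lambda s: jnp.zeros(s.shape, s.dtype)
sum_update = lambda carry, value, index: carry + value
sum_finalize = lambda carry: carry

xs = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), jnp.array([5.0, 6.0])]
total = run_accumulator(sum_init, sum_update, sum_finalize, xs)
```

Because the loop body is ordinary Python, this sketch favors readability over compilation; a traced implementation would more likely use a structured loop such as `jax.lax.scan`.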
- init
A function f(shape_dtype_struct) that initializes the microbatch state from the shape/dtype of a single microbatch evaluation.
- Type:
PyTreeFn
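Because `init` receives only a shape/dtype description rather than actual values, a common pattern is to build a zero-valued carry with the same structure. The sketch below assumes the argument is a pytree of `jax.ShapeDtypeStruct` leaves; `zeros_init` and the `"w"`/`"b"` keys are hypothetical.

```python
import jax
import jax.numpy as jnp


def zeros_init(shape_dtype_struct):
    """Hypothetical init: all-zeros carry mirroring the structure, shapes,
    and dtypes of one microbatch evaluation, without needing its values."""
    return jax.tree_util.tree_map(
        lambda s: jnp.zeros(s.shape, s.dtype), shape_dtype_struct
    )


# Hypothetical microbatch output: a dict of gradient-like arrays.
struct = {
    "w": jax.ShapeDtypeStruct((3, 2), jnp.float32),
    "b": jax.ShapeDtypeStruct((2,), jnp.float32),
}
carry = zeros_init(struct)
```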
- update
A function f(carry, value, index) that updates the microbatch state with the function evaluation of the current microbatch.
- Type:
UpdateFn
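The `index` argument lets `update` do more than a plain sum; for example, it can maintain a running mean without storing a separate counter in the carry. A minimal sketch, where `running_mean_update` is a hypothetical name:

```python
import jax.numpy as jnp


def running_mean_update(carry, value, index):
    """Hypothetical update: after microbatch `index`, carry holds the mean
    of the first index + 1 evaluations."""
    return carry + (value - carry) / (index + 1)


carry = jnp.zeros(2)
values = [jnp.array([2.0, 0.0]), jnp.array([4.0, 6.0]), jnp.array([6.0, 3.0])]
for i, x in enumerate(values):
    carry = running_mean_update(carry, x, i)
```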
- finalize
A function f(carry) that returns the final result from the final state.
- Type:
PyTreeFn
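`finalize` is where auxiliary state in the carry can be turned into the value the caller actually wants. In this hypothetical mean accumulator (all function names are illustrative), the carry is a `(sum, count)` pair and `finalize` divides the sum by the count and drops the count:

```python
import jax
import jax.numpy as jnp


def mean_init(s):
    # Hypothetical carry: (running sum shaped like one evaluation, count).
    return jnp.zeros(s.shape, s.dtype), jnp.array(0, dtype=jnp.int32)


def mean_update(carry, value, index):
    # index is unused here; the count is tracked in the carry instead.
    total, count = carry
    return total + value, count + 1


def mean_finalize(carry):
    # Turn the accumulated sum into a mean; the count is not returned.
    total, count = carry
    return total / count


xs = [jnp.array([1.0, 10.0]), jnp.array([3.0, 20.0])]
carry = mean_init(jax.ShapeDtypeStruct(xs[0].shape, xs[0].dtype))
for i, x in enumerate(xs):
    carry = mean_update(carry, x, i)
result = mean_finalize(carry)
```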
- aggregate
A function f(per_microbatch_value) that aggregates per-microbatch values into a single value. Used by micro_vmap.
- Type:
PyTreeFn
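As a sketch of how `aggregate` might relate to the sequential program, assume (this is an assumption, not documented behavior) that `micro_vmap` evaluates all microbatches in parallel and hands `aggregate` the results stacked along a leading axis. Under that assumption, a mean `aggregate` agrees with a running-mean `init`/`update`/`finalize` program; `mean_aggregate` and `loss_fn` are hypothetical.

```python
import jax
import jax.numpy as jnp


def mean_aggregate(per_microbatch_value):
    """Hypothetical aggregate: average over an assumed leading microbatch axis."""
    return jax.tree_util.tree_map(
        lambda v: jnp.mean(v, axis=0), per_microbatch_value
    )


def loss_fn(batch):
    # Hypothetical per-microbatch function.
    return jnp.mean(batch ** 2)


microbatches = jnp.arange(12.0).reshape(4, 3)  # 4 microbatches of 3 values
per_microbatch = jax.vmap(loss_fn)(microbatches)  # stacked, shape (4,)
aggregated = mean_aggregate(per_microbatch)

# The sequential running-mean program over the same microbatches.
sequential = jnp.zeros(())
for i in range(microbatches.shape[0]):
    sequential = sequential + (loss_fn(microbatches[i]) - sequential) / (i + 1)
```

Keeping `aggregate` consistent with the sequential fields means the vectorized and looped paths would produce the same result for the same inputs.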