optax.microbatching.Accumulator

class optax.microbatching.Accumulator(init: PyTreeFn, update: UpdateFn, finalize: PyTreeFn, aggregate: PyTreeFn)

A class for accumulating values in a microbatched function.

Given a list of microbatch function evaluations [x_0, …, x_{n-1}], this object represents the following program:


    carry = init(jax.typeof(x_0))
    for i in range(n):
        carry = update(carry, x_i, i)
    return finalize(carry)
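The program above can be made concrete with a minimal pure-Python sketch. It relies only on the four fields in the signature. SketchAccumulator, ShapeDtype, typeof and run_sequential are hypothetical stand-ins, not Optax or JAX APIs, so the loop runs without JAX. Microbatch evaluations are plain lists of floats.

```python
from typing import Any, Callable, NamedTuple, Sequence


class ShapeDtype(NamedTuple):
    """Hypothetical stand-in for the shape/dtype description that init receives."""
    shape: tuple
    dtype: type


def typeof(x: Sequence[float]) -> ShapeDtype:
    # Hypothetical helper playing the role of jax.typeof for 1-D lists of floats.
    return ShapeDtype(shape=(len(x),), dtype=float)


class SketchAccumulator(NamedTuple):
    """Hypothetical stand-in with the same four fields as optax.microbatching.Accumulator."""
    init: Callable[[ShapeDtype], Any]
    update: Callable[[Any, Any, int], Any]
    finalize: Callable[[Any], Any]
    aggregate: Callable[[Any], Any]


def run_sequential(acc: SketchAccumulator, xs: Sequence[Sequence[float]]) -> Any:
    """Runs the init -> update* -> finalize program for evaluations xs."""
    carry = acc.init(typeof(xs[0]))
    for i, x_i in enumerate(xs):
        carry = acc.update(carry, x_i, i)
    return acc.finalize(carry)


# Example accumulator: element-wise sum of the per-microbatch vectors.
sum_acc = SketchAccumulator(
    init=lambda s: [s.dtype(0)] * s.shape[0],  # zeros with one evaluation's shape
    update=lambda carry, x, i: [c + v for c, v in zip(carry, x)],
    finalize=lambda carry: carry,  # the running sum is already the result
    aggregate=lambda stacked: [sum(col) for col in zip(*stacked)],
)

print(run_sequential(sum_acc, [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))  # [9.0, 12.0]
```

The aggregate field is not used on this sequential path.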

init

A function f(shape_dtype_struct) that initializes the microbatch state from the shape/dtype of a single microbatch evaluation.

Type:

PyTreeFn
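init receives only the shape and dtype of a single evaluation, not its value, so the state it builds cannot depend on the data. A common choice is an all-zero state. In JAX that would typically be jnp.zeros(s.shape, s.dtype). The sketch below is a pure-Python analogue. ShapeDtype is a hypothetical stand-in for the struct that init receives, and nested lists play the role of arrays.

```python
from typing import Any, NamedTuple


class ShapeDtype(NamedTuple):
    """Hypothetical stand-in for a shape/dtype struct such as jax.ShapeDtypeStruct."""
    shape: tuple
    dtype: type


def zeros_init(s: ShapeDtype) -> Any:
    """Hypothetical init: an all-zero state with one evaluation's shape and dtype."""
    if not s.shape:
        return s.dtype(0)  # scalar evaluation
    head, *rest = s.shape
    # Build one zero sub-state per entry along the leading dimension.
    return [zeros_init(ShapeDtype(tuple(rest), s.dtype)) for _ in range(head)]


print(zeros_init(ShapeDtype((2, 3), float)))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(zeros_init(ShapeDtype((), int)))  # 0
```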

update

A function f(carry, value, index) that folds the function evaluation of the current microbatch (value) into the microbatch state; index is the 0-based position of that microbatch.

Type:

UpdateFn
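The index argument lets update depend on a microbatch's position. A minimal sketch with scalar evaluations is an incremental mean. In the program above, indices run from 0 to n-1, so index + 1 values have been seen after each step. running_mean_update is a hypothetical name, not part of Optax.

```python
def running_mean_update(carry: float, value: float, index: int) -> float:
    """Hypothetical update: mean of the evaluations for microbatches 0..index."""
    # index is 0-based, so this step brings the count of seen values to index + 1.
    return carry + (value - carry) / (index + 1)


carry = 0.0  # initial state
for i, x in enumerate([2.0, 4.0, 9.0]):
    carry = running_mean_update(carry, x, i)
print(carry)  # 5.0
```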

finalize

A function f(carry) that maps the final microbatch state to the result returned by the microbatched function.

Type:

PyTreeFn
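finalize allows the carry to have a different structure from the returned result. The hypothetical mean accumulator below uses scalar evaluations, pure Python, and names that are not from Optax. Its carry is a (sum, count) pair, and finalize divides the sum by the count.

```python
from typing import Any, Tuple

Carry = Tuple[float, int]  # (running sum, number of microbatches seen)


def mean_init(_shape_dtype: Any) -> Carry:
    # Hypothetical init: shape/dtype is ignored for scalar evaluations.
    return (0.0, 0)


def mean_update(carry: Carry, value: float, index: int) -> Carry:
    # Hypothetical update: index is not needed because the carry tracks the count.
    total, count = carry
    return (total + value, count + 1)


def mean_finalize(carry: Carry) -> float:
    """Hypothetical finalize: turns the (sum, count) state into the mean."""
    total, count = carry
    return total / count


carry = mean_init(None)
for i, x in enumerate([1.0, 2.0, 6.0]):
    carry = mean_update(carry, x, i)
print(carry)  # (9.0, 3)
print(mean_finalize(carry))  # 3.0
```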

aggregate

A function f(per_microbatch_value) that aggregates per-microbatch values into a single value. Used by micro_vmap.

Type:

PyTreeFn
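The description above says only that aggregate combines per-microbatch values and is used by micro_vmap. One plausible reading, which is an assumption here, is that the vectorized path produces all per-microbatch values at once, stacked along a leading axis, and aggregate reduces over that axis. Under that reading, a sum accumulator's aggregate should agree with the sequential init/update/finalize loop. The pure-Python sketch below checks that agreement. sum_update and sum_aggregate are hypothetical, and in JAX the aggregate would more likely be something like jnp.sum(x, axis=0).

```python
from typing import List

Vector = List[float]


def sum_update(carry: Vector, value: Vector, index: int) -> Vector:
    # Sequential path: add one microbatch's vector into the running sum.
    return [c + v for c, v in zip(carry, value)]


def sum_aggregate(stacked: List[Vector]) -> Vector:
    """Hypothetical aggregate: reduce over the leading (microbatch) axis."""
    return [sum(column) for column in zip(*stacked)]


per_microbatch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Sequential path: zero init, one update per microbatch, identity finalize.
carry = [0.0, 0.0]
for i, x in enumerate(per_microbatch):
    carry = sum_update(carry, x, i)

# Assumed vectorized path: all per-microbatch values reduced in one call.
aggregated = sum_aggregate(per_microbatch)

print(carry, aggregated)  # [9.0, 12.0] [9.0, 12.0]
assert carry == aggregated
```

A sum is used because its finalize is the identity, so the sketch does not depend on whether finalize is also applied after aggregate.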