optax.MultiSteps

class optax.MultiSteps(opt: optax.GradientTransformation, every_k_schedule: int | collections.abc.Callable[[jax.typing.ArrayLike], jax.typing.ArrayLike], use_grad_mean: bool = True, should_skip_update_fn: optax.transforms._accumulation.ShouldSkipUpdateFunction | None = None, accumulator_dtype: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType = jax.numpy.float32)

An optimizer wrapper to accumulate gradients over multiple steps.

This wrapper accumulates the updates passed to its update function over consecutive steps until a scheduled number of steps is reached. On each of these intermediate steps, the value returned by the wrapper is a tree of zeros with the same shape as the updates passed as input.

Once the scheduled number of intermediate ‘mini-steps’ has been reached, the gradients accumulated so far are passed to the wrapped optimizer’s update function (which also updates the inner optimizer’s state), and the resulting updates are returned to the caller. The wrapper’s accumulated gradients are then reset to zero and the process starts again.

The number of mini-steps per gradient update is controlled by every_k_schedule, which may be a fixed integer or a function, so it can vary over training. This also makes it possible to vary the effective batch size over training.