optax.MultiStepsState

class optax.MultiStepsState(mini_step: jax.typing.ArrayLike, gradient_step: jax.typing.ArrayLike, inner_opt_state: Any, acc_grads: Any, skip_state: base.ArrayTree = ())

State of the GradientTransformation returned by MultiSteps.

mini_step

Current mini-step counter. On each update it either increases by 1 or, once enough mini-steps have been accumulated, is reset to 0.

Type:

jax.typing.ArrayLike

gradient_step

Gradient step counter. It increases by 1 only after enough mini-steps have been accumulated, i.e. each time the wrapped optimizer is actually applied.

Type:

jax.typing.ArrayLike

inner_opt_state

The state of the wrapped (inner) optimizer.

Type:

Any

acc_grads

Gradients accumulated over multiple mini-steps.

Type:

Any

skip_state

An arbitrary pytree. It is only relevant when a should_skip_update_fn is passed to MultiSteps.

Type:

base.ArrayTree