optax.MultiStepsState
- class optax.MultiStepsState(mini_step: jax.typing.ArrayLike, gradient_step: jax.typing.ArrayLike, inner_opt_state: Any, acc_grads: Any, skip_state: base.ArrayTree = ())
State of the GradientTransformation returned by MultiSteps.
- mini_step
current mini-step counter. On each update, this is either incremented by 1 or, once enough mini-steps have been accumulated to take a gradient step, reset to 0.
- Type:
jax.typing.ArrayLike
- gradient_step
gradient step counter. This is incremented by 1 only when enough mini-steps have been accumulated and the wrapped optimizer is applied.
- Type:
jax.typing.ArrayLike
- inner_opt_state
the state of the wrapped optimizer.
- Type:
Any
- acc_grads
gradients accumulated over the mini-steps since the last gradient step.
- Type:
Any
- skip_state
an arbitrary pytree holding the auxiliary state returned by should_skip_update_fn. This is only relevant when a should_skip_update_fn is passed to MultiSteps.
- Type:
base.ArrayTree