optax.contrib.MechanicState
- class optax.contrib.MechanicState(base_optimizer_state: optax.OptState, count: jax.typing.ArrayLike, r: jax.typing.ArrayLike, m: jax.typing.ArrayLike, v: jax.typing.ArrayLike, s: jax.typing.ArrayLike, x0: optax.Updates)
State of the GradientTransformation returned by mechanize.