optax.contrib.MomoState
- class optax.contrib.MomoState(exp_avg: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)
State of the GradientTransformation returned by optax.contrib.momo.