optax.contrib.MomoAdamState
class optax.contrib.MomoAdamState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)
State of the GradientTransformation returned by momo_adam.