optax.contrib.MomoAdamState



class optax.contrib.MomoAdamState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)

State of the GradientTransformation returned by momo_adam.