optax.contrib.MomoState

class optax.contrib.MomoState(exp_avg: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)

State of the GradientTransformation returned by optax.contrib.momo.