optax.contrib.ScaleByAdemamixState
- class optax.contrib.ScaleByAdemamixState(count: jax.typing.ArrayLike, count_m2: jax.typing.ArrayLike, m1: optax.Updates, m2: optax.Updates, nu: optax.Updates)
State for the AdEMAMix algorithm.
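As a rough illustration of how this state is laid out, the sketch below builds an initial state by hand: two scalar step counters and three moment buffers that share the pytree structure of the parameters. It assumes only that `jax` is installed. `AdemamixStateSketch` and `init_state_sketch` are hypothetical stand-ins that mirror the fields documented on this page; they are not Optax APIs. Starting every counter and buffer at zero is an assumption about what an initializer would plausibly produce.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class AdemamixStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as ScaleByAdemamixState."""

    count: jax.Array     # steps used for the fast EMA and second moment
    count_m2: jax.Array  # steps used for the slow EMA and alpha
    m1: Any              # fast EMA of the gradients, same pytree as params
    m2: Any              # slow EMA of the gradients, same pytree as params
    nu: Any              # second-moment estimate, same pytree as params


def init_state_sketch(params: Any) -> AdemamixStateSketch:
    """Assumed initialization: zero counters and zero moment buffers."""
    # JAX arrays are immutable, so one zero pytree can back all three buffers.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return AdemamixStateSketch(
        count=jnp.zeros((), dtype=jnp.int32),
        count_m2=jnp.zeros((), dtype=jnp.int32),
        m1=zeros,
        m2=zeros,
        nu=zeros,
    )


# Example parameters: a small weight matrix and a bias vector.
params = {"w": jnp.ones((2, 3)), "b": jnp.ones((3,))}
state = init_state_sketch(params)
```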
- count
Step counter used when updating the fast EMA (m1) and the second-moment estimate (nu).
- Type:
jax.typing.ArrayLike
- count_m2
Step counter used when updating the slow EMA (m2) and the coefficient alpha that weights it.
- Type:
jax.typing.ArrayLike
- m1
Fast exponential moving average (EMA) of the gradients, i.e. a first-moment estimate.
- Type:
optax.Updates
- m2
Slow EMA of the gradients, i.e. a second, longer-horizon first-moment estimate.
- Type:
optax.Updates
- nu
Estimate of the second moment (an EMA of the squared gradients).
- Type:
optax.Updates
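To show how the five fields interact during one step, the sketch below applies a minimal version of the published AdEMAMix scaling rule to a pytree of gradients. m1 and nu are Adam-style EMAs that are bias-corrected with count; m2 is a slow EMA whose decay b3 and weight alpha are looked up from count_m2 through schedule callables; the scaled update is (m1_hat + alpha * m2) / (sqrt(nu_hat) + eps). `AdemamixStateSketch`, `ademamix_step_sketch`, the schedule callables, and the default hyperparameter values are hypothetical illustrations. This is not Optax's implementation, and it omits the learning rate and weight decay.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class AdemamixStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as ScaleByAdemamixState."""

    count: jax.Array
    count_m2: jax.Array
    m1: Any
    m2: Any
    nu: Any


def ademamix_step_sketch(
    grads: Any,
    state: AdemamixStateSketch,
    b1: float = 0.9,
    b2: float = 0.999,
    b3_schedule: Callable[[jax.Array], Any] = lambda step: 0.9999,
    alpha_schedule: Callable[[jax.Array], Any] = lambda step: 5.0,
    eps: float = 1e-8,
) -> tuple[Any, AdemamixStateSketch]:
    """One AdEMAMix-style scaling step (no learning rate, no weight decay)."""
    tmap = jax.tree_util.tree_map
    # The slow-EMA decay and alpha are looked up from count_m2 (e.g. warmup).
    b3 = b3_schedule(state.count_m2)
    alpha = alpha_schedule(state.count_m2)
    m1 = tmap(lambda m, g: b1 * m + (1 - b1) * g, state.m1, grads)
    m2 = tmap(lambda m, g: b3 * m + (1 - b3) * g, state.m2, grads)
    nu = tmap(lambda v, g: b2 * v + (1 - b2) * g**2, state.nu, grads)
    count = state.count + 1
    count_m2 = state.count_m2 + 1
    # m1 and nu are bias-corrected with count; m2 is used without correction.
    t = count.astype(jnp.float32)
    c1, c2 = 1 - b1**t, 1 - b2**t
    updates = tmap(
        lambda a, s, v: (a / c1 + alpha * s) / (jnp.sqrt(v / c2) + eps),
        m1, m2, nu,
    )
    return updates, AdemamixStateSketch(count, count_m2, m1, m2, nu)


# One step from an all-zero state with unit gradients.
grads = {"w": jnp.ones((2,))}
zeros = jax.tree_util.tree_map(jnp.zeros_like, grads)
state = AdemamixStateSketch(
    jnp.zeros((), jnp.int32), jnp.zeros((), jnp.int32), zeros, zeros, zeros
)
updates, new_state = ademamix_step_sketch(grads, state)
```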