optax.contrib.ScaleByAdemamixState

class optax.contrib.ScaleByAdemamixState(count: jax.typing.ArrayLike, count_m2: jax.typing.ArrayLike, m1: optax.Updates, m2: optax.Updates, nu: optax.Updates)

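State for the AdEMAMix algorithm: two step counters plus a fast and a slow exponential moving average (EMA) of the gradients and an estimate of their second moment.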
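To make the layout of this state concrete, here is a minimal sketch that mirrors its five fields with a plain `NamedTuple` and builds an all-zero initial state. It is not optax's code: `AdemamixStateSketch` and `init_state` are hypothetical names, and parameters are simplified to a flat dict of floats instead of a JAX pytree (in optax, `m1`, `m2` and `nu` have the same tree structure as the parameters).

```python
from typing import Dict, NamedTuple

# Hypothetical simplification: parameters as a flat dict of floats
# rather than a JAX pytree of arrays.
Params = Dict[str, float]


class AdemamixStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring optax.contrib.ScaleByAdemamixState."""

    count: int      # steps taken by the fast EMA and second moment
    count_m2: int   # steps taken by the slow EMA and alpha
    m1: Params      # fast EMA of the gradients
    m2: Params      # slow EMA of the gradients
    nu: Params      # EMA of the squared gradients


def init_state(params: Params) -> AdemamixStateSketch:
    """Return an all-zero state with one slot per parameter."""
    zeros = {name: 0.0 for name in params}
    # Copy the dict for each field so the three accumulators stay independent.
    return AdemamixStateSketch(
        count=0, count_m2=0, m1=dict(zeros), m2=dict(zeros), nu=dict(zeros)
    )


state = init_state({"w": 1.5, "b": -0.2})
print(state.count, state.m1)  # 0 {'w': 0.0, 'b': 0.0}
```

Only the parameter names matter at initialization; the parameter values themselves are not stored in the state.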

count

Step count used when updating the fast EMA and the second-moment estimate.

Type:

jax.typing.ArrayLike

count_m2

Step count used when updating the slow EMA and the mixing coefficient alpha.

Type:

jax.typing.ArrayLike

m1

Fast EMA of the gradients (first-moment estimate).

Type:

optax.Updates

m2

Slow EMA of the gradients (first-moment estimate).

Type:

optax.Updates

nu

Estimate of the second moment (EMA of the squared gradients).

Type:

optax.Updates
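To show how the five fields interact, here is a minimal pure-Python sketch of a single step, assuming the AdEMAMix update rule from the original paper: the bias-corrected fast EMA plus alpha times the slow EMA (which is not bias-corrected), divided by the square root of the bias-corrected second moment. It is an illustration, not optax's implementation. The names `ademamix_step` and `alpha_warmup`, the default coefficients, and the linear alpha warm-up are hypothetical choices; `count` drives the bias correction and `count_m2` drives the alpha schedule, matching the field descriptions above.

```python
import math
from typing import Callable, Dict, NamedTuple, Tuple

Params = Dict[str, float]  # hypothetical: flat dict of scalar parameters


class AdemamixStateSketch(NamedTuple):
    count: int      # drives bias correction of m1 and nu
    count_m2: int   # drives the (hypothetical) alpha schedule
    m1: Params      # fast EMA of the gradients
    m2: Params      # slow EMA of the gradients
    nu: Params      # EMA of the squared gradients


def alpha_warmup(step: int, alpha: float = 5.0, warmup: int = 100) -> float:
    """Hypothetical linear warm-up of the mixing coefficient alpha."""
    return alpha * min(step / warmup, 1.0)


def ademamix_step(
    grads: Params,
    state: AdemamixStateSketch,
    b1: float = 0.9,
    b2: float = 0.999,
    b3: float = 0.9999,
    alpha_fn: Callable[[int], float] = alpha_warmup,
    eps: float = 1e-8,
) -> Tuple[Params, AdemamixStateSketch]:
    """One step; returns the update direction before any learning rate."""
    count = state.count + 1
    count_m2 = state.count_m2 + 1
    alpha = alpha_fn(count_m2)
    m1 = {k: b1 * state.m1[k] + (1 - b1) * g for k, g in grads.items()}
    m2 = {k: b3 * state.m2[k] + (1 - b3) * g for k, g in grads.items()}
    nu = {k: b2 * state.nu[k] + (1 - b2) * g * g for k, g in grads.items()}
    updates = {}
    for k in grads:
        m1_hat = m1[k] / (1 - b1**count)  # bias-corrected fast EMA
        nu_hat = nu[k] / (1 - b2**count)  # bias-corrected second moment
        # The slow EMA m2 is mixed in without bias correction.
        updates[k] = (m1_hat + alpha * m2[k]) / (math.sqrt(nu_hat) + eps)
    return updates, AdemamixStateSketch(count, count_m2, m1, m2, nu)


zero = {"w": 0.0}
state = AdemamixStateSketch(0, 0, dict(zero), dict(zero), dict(zero))
updates, state = ademamix_step({"w": 1.0}, state)
print(state.count, round(updates["w"], 6))  # 1 1.000005
```

Because `m2` is not bias-corrected and alpha starts near zero under the warm-up, the slow EMA contributes almost nothing early on; its weight grows as both `m2` and alpha build up. The paper also schedules `b3`, which this sketch keeps constant for brevity.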