optax.contrib.ScaleBySimplifiedAdEMAMixState

class optax.contrib.ScaleBySimplifiedAdEMAMixState(t: jax.typing.ArrayLike, m: optax.Updates, n: optax.Updates)

State for the Simplified AdEMAMix optimizer.
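The state is a plain container of three fields, so it can be sketched as a `NamedTuple`. JAX flattens a `NamedTuple` as a pytree, which lets the state pass through `jax.jit` and `jax.tree_util` unchanged. The class name `SimplifiedAdEMAMixStateSketch`, the helper `init_state`, and the zero initialisation are hypothetical illustrations that mirror the documented fields. They are not optax's own code.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical mirror of the documented state (fields t, m, n).
# Being a NamedTuple, JAX treats it as a pytree automatically.
class SimplifiedAdEMAMixStateSketch(NamedTuple):
    t: jax.Array  # iteration count (scalar)
    m: Any        # EMA of gradients; same tree structure as params
    n: Any        # second-moment estimate; same tree structure as params


def init_state(params: Any) -> SimplifiedAdEMAMixStateSketch:
    """Zero-initialise the counter and both moment trees (assumed init)."""
    return SimplifiedAdEMAMixStateSketch(
        t=jnp.zeros([], jnp.int32),
        m=jax.tree_util.tree_map(jnp.zeros_like, params),
        n=jax.tree_util.tree_map(jnp.zeros_like, params),
    )


# Example: a small parameter tree with two leaves.
params = {"w": jnp.ones((2, 3)), "b": jnp.ones((3,))}
state = init_state(params)
```

Each moment tree mirrors the structure of `params`, which is why `m` and `n` are typed as `Updates` (a pytree) rather than as single arrays.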

t

Iteration (step) count.

Type:

jax.typing.ArrayLike

m

Exponential moving average (EMA) of the gradients.

Type:

optax.Updates

n

Second-moment estimate (EMA of the squared gradients).

Type:

optax.Updates
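The three attributes change once per step: `t` advances by one, while `m` and `n` track running averages of the gradients and of their squares. The sketch below uses standard Adam-style EMAs with hypothetical decay rates `b1` and `b2`. The function `update_moments` is illustrative only. The exact Simplified AdEMAMix rule used by optax is not given on this page and may differ, for example in the momentum weighting, bias correction, or how the final update is formed from `m` and `n`.

```python
import jax
import jax.numpy as jnp


def update_moments(t, m, n, grads, b1, b2):
    """Advance (t, m, n) by one step with Adam-style EMAs (assumed rule).

    m: EMA of the gradients, n: EMA of the squared gradients.
    b1, b2 are hypothetical decay rates in [0, 1).
    """
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    n = jax.tree_util.tree_map(lambda n_, g: b2 * n_ + (1 - b2) * g * g, n, grads)
    return t + 1, m, n


# Example: one step from a zero state with gradient 2.0 everywhere.
grads = {"w": jnp.full((2,), 2.0)}
t, m, n = jnp.zeros([], jnp.int32), {"w": jnp.zeros(2)}, {"w": jnp.zeros(2)}
t, m, n = update_moments(t, m, n, grads, b1=0.5, b2=0.75)
# m["w"] == 0.5 * 2.0 == 1.0 and n["w"] == 0.25 * 4.0 == 1.0
```

Returning a new `(t, m, n)` tuple instead of mutating in place matches JAX's functional style, so the step can be wrapped in `jax.jit`.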