optax.contrib.ScaleBySimplifiedAdEMAMixState
- class optax.contrib.ScaleBySimplifiedAdEMAMixState(t: jax.typing.ArrayLike, m: optax.Updates, n: optax.Updates)
State for the Simplified AdEMAMix optimizer.
- t
Iteration count.
- Type:
jax.typing.ArrayLike
- m
Exponential moving average (EMA) of the updates.
- Type:
optax.Updates
- n
Second-moment estimate of the updates.
- Type:
optax.Updates
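The sketch below illustrates how a state with these three fields could be initialized from a parameter pytree and advanced by one step. It is not the optax implementation: `SimplifiedAdEMAMixStateSketch`, `init_state`, and `advance_state` are hypothetical names; the state is modelled as a `NamedTuple`, as many optax states are; and the update rules for `m` and `n` are assumed standard Adam-style EMAs with illustrative decay rates. The actual Simplified AdEMAMix rule for `m` may differ (for example, in how the new update is weighted).

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-in mirroring the fields of
# optax.contrib.ScaleBySimplifiedAdEMAMixState; NOT the optax class itself.
class SimplifiedAdEMAMixStateSketch(NamedTuple):
    t: jax.Array  # iteration count (scalar int32)
    m: Any        # EMA of the updates; same pytree structure as params
    n: Any        # second-moment estimate; same pytree structure as params


def init_state(params):
    """Zero-initialize every field; m and n mirror the parameter pytree."""
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return SimplifiedAdEMAMixStateSketch(
        t=jnp.zeros([], jnp.int32), m=zeros, n=zeros
    )


def advance_state(grads, state, b1=0.9, b2=0.999):
    """One step under ASSUMED Adam-style EMA rules (may differ from optax)."""
    m = jax.tree_util.tree_map(
        lambda g, m_old: b1 * m_old + (1 - b1) * g, grads, state.m
    )
    n = jax.tree_util.tree_map(
        lambda g, n_old: b2 * n_old + (1 - b2) * g**2, grads, state.n
    )
    return SimplifiedAdEMAMixStateSketch(t=state.t + 1, m=m, n=n)


params = {"w": jnp.ones((2,)), "b": jnp.zeros(())}
state = init_state(params)
grads = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
state = advance_state(grads, state)
print(int(state.t))  # 1
```

Because `m` and `n` share the structure of the parameters, a state like this can be passed through `jax.jit` or `jax.tree_util` functions like any other pytree.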