optax.EmaState

class optax.EmaState(count: jax.typing.ArrayLike, ema: optax.Params)

Holds an exponential moving average of past updates.
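
To make the two fields concrete, here is a minimal pure-Python sketch of how a state of this shape might evolve. It is not the optax implementation: `SketchEmaState`, `init_state`, `ema_step`, and the `decay` value are hypothetical names chosen for illustration, plain lists of floats stand in for `optax.Params`, and bias correction is left out.

```python
from typing import NamedTuple


class SketchEmaState(NamedTuple):
    """Hypothetical stand-in with the same two fields as optax.EmaState."""
    count: int   # number of updates folded into the average so far
    ema: list    # running average, same structure as the updates


def init_state(params):
    # Start with a zero count and a zero average shaped like the parameters.
    return SketchEmaState(count=0, ema=[0.0 for _ in params])


def ema_step(updates, state, decay):
    # Standard exponential moving average: new = decay * old + (1 - decay) * update.
    new_ema = [decay * e + (1.0 - decay) * u for e, u in zip(state.ema, updates)]
    return SketchEmaState(count=state.count + 1, ema=new_ema)


state = init_state([0.0, 0.0])
for _ in range(2):
    state = ema_step([1.0, 2.0], state, decay=0.9)

print(state.count, state.ema)
```

Keeping `count` alongside `ema` lets a caller correct the bias toward the zero start, for example by dividing each entry by `1 - decay ** count`. Whether that correction is applied is a choice made by the code that consumes the state, not by the state itself.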