optax.EmaState

class optax.EmaState(count: jax.typing.ArrayLike, ema: optax.Params)

Holds an exponential moving average of past updates.