optax.ema
- optax.ema(decay: jax.typing.ArrayLike, debias: bool = True, accumulator_dtype: Any | None = None) -> optax.GradientTransformation
Compute an exponential moving average of past updates.
- Parameters:
decay – Decay rate for the exponential moving average.
debias – Whether to debias the transformed gradient.
accumulator_dtype – Optional dtype to use for the accumulator; if None, the dtype is inferred from params and updates.
- Returns:
An optax.GradientTransformation object.
Note
optax.trace() and optax.ema() have very similar but distinct updates: trace = decay * trace + t, while ema = decay * ema + (1 - decay) * t. Both are frequently found in the optimization literature.
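The difference between the two recurrences can be sketched in plain Python on scalars. The helper names `trace_step` and `ema_step` are hypothetical, and the debiasing line assumes the standard bias correction `ema / (1 - decay**step)`; the text above does not spell out the formula behind `debias`.

```python
def trace_step(trace, t, decay):
    """One step of the trace recurrence: decay * trace + t."""
    return decay * trace + t


def ema_step(ema, t, decay):
    """One step of the EMA recurrence: decay * ema + (1 - decay) * t."""
    return decay * ema + (1.0 - decay) * t


decay = 0.9
trace, ema = 0.0, 0.0

# Feed a constant update t = 1.0 for three steps.
for step in range(1, 4):
    trace = trace_step(trace, 1.0, decay)
    ema = ema_step(ema, 1.0, decay)
    # Assumed bias correction (not stated in the text above).
    debiased = ema / (1.0 - decay ** step)

print(trace, ema, debiased)
```

When both accumulators start at zero and see the same inputs, the EMA is the trace scaled by `(1 - decay)`: the trace accumulates a growing sum, while the EMA stays on the scale of the inputs.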