optax.ema

optax.ema(decay: jax.typing.ArrayLike, debias: bool = True, accumulator_dtype: Any | None = None) → optax.GradientTransformation

Compute an exponential moving average of past updates.

Parameters:
  • decay – Decay rate for the exponential moving average.

  • debias – Whether to debias the transformed gradient, correcting for the bias toward zero that comes from initializing the moving average at zero.

  • accumulator_dtype – Optional dtype to use for the accumulator; if None, the dtype is inferred from params and updates.
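The update rule and the effect of debias can be illustrated with a minimal scalar sketch in plain Python; this is not optax's implementation. It assumes the accumulator starts at zero and that debiasing divides the average by 1 - decay**count, in the style of Adam's bias correction. The helper name ema_sequence is hypothetical.

```python
# Scalar sketch of an exponential moving average with optional debiasing.
# Assumptions (not taken from optax itself): the accumulator starts at 0.0
# and debiasing divides by (1 - decay**count), Adam-style.

def ema_sequence(values, decay, debias=True):
    """Return the (optionally debiased) EMA after each new value."""
    ema = 0.0  # zero initialization is what biases early averages toward 0
    outputs = []
    for count, t in enumerate(values, start=1):
        ema = decay * ema + (1.0 - decay) * t
        if debias:
            # Undo the shrinkage toward the zero initialization.
            outputs.append(ema / (1.0 - decay ** count))
        else:
            outputs.append(ema)
    return outputs


constant = [2.0] * 5
raw = ema_sequence(constant, decay=0.9, debias=False)
debiased = ema_sequence(constant, decay=0.9, debias=True)

# The raw EMA starts near 0.2 and creeps up toward 2.0, while the debiased
# EMA stays at 2.0 (up to floating-point rounding) from the first step.
print(raw)
print(debiased)
```

With a constant input, the raw average after k steps is (1 - decay**k) times the input, so dividing by that same factor recovers the input exactly in exact arithmetic.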

Returns:

An optax.GradientTransformation object.

Note

optax.trace() and optax.ema() have very similar but distinct updates; trace = decay * trace + t, while ema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.
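The difference between the two recurrences in the note can be seen numerically with a short plain-Python sketch; it does not call optax. Both accumulators are assumed to start at zero with no debiasing, and the helpers run_trace and run_ema are hypothetical names for illustration.

```python
# Contrast the two scalar recurrences from the note (plain Python, not optax).
# Assumption: both accumulators start at 0.0 and no debiasing is applied.

def run_trace(values, decay):
    """trace = decay * trace + t"""
    trace = 0.0
    history = []
    for t in values:
        trace = decay * trace + t
        history.append(trace)
    return history


def run_ema(values, decay):
    """ema = decay * ema + (1 - decay) * t"""
    ema = 0.0
    history = []
    for t in values:
        ema = decay * ema + (1.0 - decay) * t
        history.append(ema)
    return history


decay = 0.9
inputs = [1.0] * 200  # a constant stream of updates

trace_hist = run_trace(inputs, decay)
ema_hist = run_ema(inputs, decay)

# For a constant input, trace approaches 1 / (1 - decay) = 10,
# while ema approaches the input itself, 1.
print(trace_hist[-1], ema_hist[-1])

# With zero initialization, ema equals trace rescaled by (1 - decay),
# up to floating-point rounding.
print(max(abs((1.0 - decay) * a - b) for a, b in zip(trace_hist, ema_hist)))
```

In other words, trace accumulates a decayed sum whose scale grows with 1 / (1 - decay), whereas ema keeps a weighted average on the same scale as its inputs.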