optax.measure_with_ema

optax.measure_with_ema(measure: Callable[[base.Updates], base.ArrayTree], decay: jax.typing.ArrayLike, debias: bool = True, accumulator_dtype: Any | None = None) → base.GradientTransformationExtraArgs

Take a measurement of the updates and record it with an exponential moving average.

Parameters:
  • measure – User-supplied callable that takes the updates as input and returns the desired measurement (a pytree of arrays).

  • decay – Decay rate for the exponential moving average.

  • debias – Whether to debias the exponential moving average.

  • accumulator_dtype – Optional dtype for the exponential moving average accumulator.
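As a sketch of the arithmetic that decay and debias control, assume the standard moving-average update and bias correction used by optax.ema, with the accumulator starting at zero (an assumption; this page does not spell out the formula). Write x_t = measure(u_t) for the measurement of the updates at step t and β for decay:

```latex
\begin{aligned}
m_0 &= 0, \\
m_t &= \beta\, m_{t-1} + (1 - \beta)\, x_t, \\
\hat{m}_t &= \frac{m_t}{1 - \beta^{t}} \qquad \text{(only when } \texttt{debias=True}\text{)}.
\end{aligned}
```

Because m_0 = 0, early values of m_t are pulled toward zero; dividing by 1 − β^t undoes that pull, so a constant measurement x_t = c gives \hat{m}_t = c at every step.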

Returns:

A gradient transformation that captures user-defined measurements and records them with an exponential moving average.
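To illustrate the mechanics, here is a minimal pure-JAX sketch. It does not call Optax and is not Optax's implementation. The names ema_measure_sketch, EmaMeasureState, global_norm and the separate read function are all hypothetical. The sketch assumes that the updates pass through unchanged and that the measurement's moving average lives only in the state.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class EmaMeasureState(NamedTuple):
    """Hypothetical state: number of steps seen and the raw EMA accumulator."""
    count: jax.Array
    ema: Any  # pytree with the same structure as measure(updates)


def global_norm(tree):
    """Example measurement: L2 norm over all leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def ema_measure_sketch(measure: Callable, decay: float, debias: bool = True):
    """Hypothetical stand-in for measure_with_ema; returns (init, update, read)."""

    def init(params):
        # Assumption: shape the accumulator like the measurement of a zero update.
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        ema = jax.tree_util.tree_map(jnp.zeros_like, measure(zeros))
        return EmaMeasureState(count=jnp.zeros([], jnp.int32), ema=ema)

    def update(updates, state):
        x = measure(updates)
        # m_t = decay * m_{t-1} + (1 - decay) * x_t, leaf by leaf.
        ema = jax.tree_util.tree_map(
            lambda m, v: decay * m + (1.0 - decay) * v, state.ema, x)
        # Updates are returned untouched; only the state records the measurement.
        return updates, EmaMeasureState(count=state.count + 1, ema=ema)

    def read(state):
        # Call only after at least one update: with count == 0 the correction is 0.
        if not debias:
            return state.ema
        correction = 1.0 - decay ** state.count
        return jax.tree_util.tree_map(lambda m: m / correction, state.ema)

    return init, update, read


params = {"w": jnp.ones(3), "b": jnp.zeros(())}
init, update, read = ema_measure_sketch(global_norm, decay=0.9)
state = init(params)
grads = {"w": jnp.array([3.0, 0.0, 0.0]), "b": jnp.array(4.0)}  # norm is 5.0
for _ in range(3):
    grads_out, state = update(grads, state)
print(float(read(state)))  # close to 5.0, since the measurement is constant
```

This sketch stores the raw accumulator and applies the debiasing correction only when the value is read. That is one design option. The correction could equally be applied inside update.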

See also

optax.monitor()