optax.measure_with_ema
- optax.measure_with_ema(measure: Callable[[base.Updates], base.ArrayTree], decay: jax.typing.ArrayLike, debias: bool = True, accumulator_dtype: Any | None = None) → base.GradientTransformationExtraArgs
Take a measurement of the updates and record it with an exponential moving average.
- Parameters:
measure – User-defined callable that takes the updates as input and returns the desired measurement.
decay – Decay rate for the exponential moving average.
debias – Whether to debias the exponential moving average.
accumulator_dtype – Optional dtype for the exponential moving average accumulator.
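The roles of `decay` and `debias` can be summarized with the usual moment-estimate formulas. This assumes the accumulator starts at zero and uses the same Adam-style bias correction as other optax moment estimates. Write $m_t$ for `measure` applied to the updates at step $t$, and $\beta$ for `decay`:

```latex
\begin{aligned}
\bar m_0 &= 0, \\
\bar m_t &= \beta\,\bar m_{t-1} + (1-\beta)\,m_t, \\
\hat m_t &= \frac{\bar m_t}{1-\beta^{t}} \qquad (\text{only when } \texttt{debias=True}).
\end{aligned}
```

Because $\bar m_t$ starts from zero, it underestimates the measurement during the first steps. Dividing by $1-\beta^{t}$ removes that bias, and the correction fades as $t$ grows.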
- Returns:
A gradient transformation that captures user-defined measurements of the updates and records them with an exponential moving average.
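The following is a minimal pure-Python model of the behavior described above. It is not optax's implementation. It assumes that the transformation leaves the updates unchanged and only stores the running average in its state. `EmaState`, `global_norm`, and `make_measure_with_ema` are hypothetical names. Applying the debiasing when the value is read, rather than storing it, is one possible design choice.

```python
import math
from typing import Callable, Dict, List, NamedTuple

# Toy stand-in for a pytree of updates: parameter name -> flat list of values.
Updates = Dict[str, List[float]]


class EmaState(NamedTuple):
    """Hypothetical state: number of steps seen and the raw (biased) EMA."""
    count: int
    ema: float


def global_norm(updates: Updates) -> float:
    """Example measure: L2 norm over every entry of every update."""
    return math.sqrt(sum(x * x for leaf in updates.values() for x in leaf))


def make_measure_with_ema(
    measure: Callable[[Updates], float], decay: float, debias: bool = True
):
    """Hypothetical sketch returning (init, update, read) functions."""

    def init() -> EmaState:
        # Assumption: the accumulator starts at zero.
        return EmaState(count=0, ema=0.0)

    def update(updates: Updates, state: EmaState):
        m = measure(updates)
        ema = decay * state.ema + (1.0 - decay) * m
        # Assumption: updates pass through unchanged; only the state changes.
        return updates, EmaState(count=state.count + 1, ema=ema)

    def read(state: EmaState) -> float:
        # Bias correction divides by 1 - decay**t to undo the zero start.
        if debias and state.count > 0:
            return state.ema / (1.0 - decay**state.count)
        return state.ema

    return init, update, read


init, update, read = make_measure_with_ema(global_norm, decay=0.9)
state = init()
for _ in range(3):
    grads = {"w": [3.0, 4.0]}  # global norm is 5.0 at every step
    grads, state = update(grads, state)
print(round(read(state), 6))  # → 5.0
print(round(state.ema, 6))  # raw EMA → 1.355
```

The example feeds in a constant measurement of 5.0 with `decay=0.9`. After three steps the raw average is only 1.355 because it starts from zero. The debiased value recovers 5.0, which is exactly the gap that `debias=True` corrects for.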
See also