optax.trace
- optax.trace(decay: jax.typing.ArrayLike, nesterov: bool = False, accumulator_dtype: Any | None = None) → optax.GradientTransformation
Compute a trace of past updates.
- Parameters:
decay – Decay rate for the trace of past updates.
nesterov – Whether to use Nesterov momentum.
accumulator_dtype – Optional dtype to be used for the accumulator; if None, the dtype is inferred from params and updates.
- Returns:
An optax.GradientTransformation object.
Note
optax.trace() and optax.ema() have very similar but distinct updates: trace = decay * trace + t, while ema = decay * ema + (1 - decay) * t. Both are frequently found in the optimization literature.
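The difference between the two update rules in the note is easiest to see on a few numbers. The following is a minimal plain-Python sketch of the two recurrences exactly as written above; it does not call Optax, the helper names `trace_step` and `ema_step` are hypothetical, and it assumes scalar updates, a zero-initialized accumulator, and no Nesterov or bias correction.

```python
def trace_step(trace, t, decay):
    """One step of the trace rule from the note: decay * trace + t."""
    return decay * trace + t


def ema_step(ema, t, decay):
    """One step of the EMA rule from the note: decay * ema + (1 - decay) * t."""
    return decay * ema + (1.0 - decay) * t


decay = 0.5
updates = [1.0, 1.0, 1.0]  # a constant stream of updates, chosen for illustration

# Both accumulators start at zero (an assumption of this sketch).
trace, ema = 0.0, 0.0
trace_history, ema_history = [], []
for t in updates:
    trace = trace_step(trace, t, decay)
    ema = ema_step(ema, t, decay)
    trace_history.append(trace)
    ema_history.append(ema)

print(trace_history)  # [1.0, 1.5, 1.75]
print(ema_history)    # [0.5, 0.75, 0.875]
```

With a constant update t, the trace grows toward t / (1 - decay) (here 2.0), whereas the EMA approaches t itself (here 1.0). When both start from zero, the EMA is the trace scaled by (1 - decay) at every step, so swapping one for the other changes the effective step size.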