optax.trace


optax.trace(decay: jax.typing.ArrayLike, nesterov: bool = False, accumulator_dtype: Any | None = None) → optax.GradientTransformation

Compute a trace of past updates.

Parameters:
  • decay – Decay rate for the trace of past updates.

  • nesterov – Whether to use Nesterov momentum.

  • accumulator_dtype – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Returns:

An optax.GradientTransformation object.

Note

optax.trace() and optax.ema() have very similar but distinct updates: trace = decay * trace + t, while ema = decay * ema + (1 - decay) * t, where t is the incoming update. Both are frequently found in the optimization literature.
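The difference between the two recurrences can be seen by writing them out by hand. The sketch below is plain Python that follows only the two formulas in the note; it is not the library's implementation and does not model any other option of either transformation. The stream of scalar updates is a made-up example.

```python
import math

decay = 0.9
stream = [1.0, 1.0, 1.0]  # hypothetical scalar updates t

trace = 0.0
ema = 0.0
for t in stream:
    trace = decay * trace + t              # trace recurrence from the note
    ema = decay * ema + (1 - decay) * t    # ema recurrence from the note
    print(f"trace={trace:.3f}  ema={ema:.3f}")
```

When both accumulators start at zero and see the same inputs, the ema value is the trace value multiplied by (1 - decay). The trace therefore grows toward t / (1 - decay) for a constant input t, while the ema stays on the scale of t itself.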