optax.scale_by_radam

optax.scale_by_radam(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, threshold: jax.typing.ArrayLike = 5.0, *, nesterov: bool = False) → optax.GradientTransformation

Rescale updates according to the Rectified Adam algorithm.

See optax.radam() for more details.

Parameters:
  • b1 – Decay rate for the exponentially weighted average of grads.

  • b2 – Decay rate for the exponentially weighted average of squared grads.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the denominator inside the square root to improve numerical stability when backpropagating gradients through the rescaling.

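  • threshold – Threshold for variance tractability: the rectified adaptive update is used only once the estimated length of the approximated simple moving average reaches this value; before that, the update falls back to the bias-corrected momentum.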

  • nesterov – Whether to use Nesterov momentum.
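To show how b1, b2, eps, eps_root and threshold interact in a single step, here is a minimal NumPy sketch of the Rectified Adam update for one array. It is an illustrative re-derivation, not optax's source: the function name `radam_direction` is hypothetical, the Nesterov variant and pytree handling are omitted, and the boundary condition `rho_t >= threshold` is an assumption about where the switch happens. Like scale_by_radam, it returns a direction that still has to be multiplied by a negative learning rate.

```python
import numpy as np


def radam_direction(grad, m, v, t, b1=0.9, b2=0.999, eps=1e-8,
                    eps_root=0.0, threshold=5.0):
    """Hypothetical single RAdam step on NumPy arrays (sketch, not optax code).

    `t` is the 1-based step count. Returns (direction, new_m, new_v); the
    direction still has to be scaled by -learning_rate. Assumes threshold > 4
    so that the square root below is always taken of a positive number.
    """
    # Exponentially weighted averages of the gradient and the squared gradient.
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * grad**2
    # Bias-corrected first and second moments.
    m_hat = m / (1.0 - b1**t)
    v_hat = v / (1.0 - b2**t)
    # rho_inf: maximum length of the approximated simple moving average (SMA);
    # rho_t: its estimated length at step t.
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    b2_t = b2**t
    rho_t = rho_inf - 2.0 * t * b2_t / (1.0 - b2_t)
    if rho_t >= threshold:
        # Variance treated as tractable: adaptive step scaled by the rectifier r_t.
        r_t = np.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
                      / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
        direction = r_t * m_hat / (np.sqrt(v_hat + eps_root) + eps)
    else:
        # Early steps: fall back to the bias-corrected momentum alone.
        direction = m_hat
    return direction, m, v


# With b2 = 0.999, rho_t stays just below 5 for the first five steps,
# so the momentum fallback is used until step 6.
grad = np.array([0.5, -2.0])
m = np.zeros_like(grad)
v = np.zeros_like(grad)
for t in range(1, 7):
    direction, m, v = radam_direction(grad, m, v, t)
    print(t, direction)
```

Because rho_t grows roughly like t during early training, a larger threshold keeps the optimizer in its momentum-only warm-up for more steps.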

Returns:

An optax.GradientTransformation object.