optax.scale_by_radam
- optax.scale_by_radam(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, threshold: jax.typing.ArrayLike = 5.0, *, nesterov: bool = False) -> optax.GradientTransformation
Rescale updates according to the Rectified Adam algorithm.
See optax.radam() for more details.
- Parameters:
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
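threshold – Threshold for variance tractability. The variance rectification is applied only once the approximated length of the simple moving average reaches this value; before that, the update falls back to the bias-corrected first moment.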
nesterov – Whether to use Nesterov momentum.
- Returns:
An optax.GradientTransformation object.