optax.radam

optax.radam(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, threshold: jax.typing.ArrayLike = 5.0, *, nesterov: bool = False) → base.GradientTransformationExtraArgs

The Rectified Adam optimizer.

The adaptive learning rate in Adam has undesirably large variance in the early stages of training, because only a limited number of training samples have been used to estimate the optimizer’s statistics at that point. Rectified Adam addresses this issue by analytically reducing that variance.
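As a rough illustration of the rectification, the sketch below computes the quantities defined in Liu et al., 2020: the approximated moving-average length ρ_t and the rectification factor r_t. The helper name radam_rectification is hypothetical, and the exact comparison against threshold is an assumption; this is not the optax implementation, only the paper's formulas written out in plain Python.

```python
import math


def radam_rectification(step, b2=0.999, threshold=5.0):
    """Hypothetical helper: return (rho_t, r_t) for a 1-based step count.

    r_t is None while rho_t is below the threshold, i.e. while the variance
    of the adaptive learning rate is treated as intractable.
    """
    # Maximum length of the approximated simple moving average.
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    b2t = b2 ** step
    # Length of the approximated simple moving average at this step.
    rho_t = rho_inf - 2.0 * step * b2t / (1.0 - b2t)
    if rho_t < threshold:  # assumed comparison; see the lead-in
        return rho_t, None
    # Variance rectification factor from the paper.
    r_t = math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
    return rho_t, r_t


for step in (1, 10, 1000, 100000):
    rho_t, r_t = radam_rectification(step)
    print(step, rho_t, r_t)
```

Under these assumptions the factor is undefined for the first few steps, starts well below 1, and approaches 1 as the step count grows, so the early adaptive steps are damped and later steps behave like Adam.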

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.

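  • threshold – Threshold for variance tractability, measured on the approximated length of the simple moving average (ρ_t in Liu et al., 2020). The variance rectification is applied only once ρ_t reaches this threshold.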

  • nesterov – Whether to use Nesterov momentum.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.radam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

References

Liu et al., 2020: On the Variance of the Adaptive Learning Rate and Beyond