optax.yogi
- optax.yogi(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 0.001) -> base.GradientTransformationExtraArgs
The Yogi optimizer.
Yogi is an adaptive optimizer that controls how the effective learning rate is tuned, so that it does not increase too rapidly. In doing so, it targets the convergence and generalization issues of adaptive methods based on exponential moving averages (such as Adam and RMSprop). Yogi is a modification of Adam and takes the same parameters.
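To illustrate the difference from Adam, the sketch below compares only the second-moment updates of the two methods on scalar gradients, following the rule v_t = v_{t-1} - (1 - b2) * sign(v_{t-1} - g_t^2) * g_t^2 from Zaheer et al. It is a simplified illustration, not Optax's implementation: bias correction, the first moment and eps are left out, and the helper names, the starting value of 1.0 and b2 = 0.9 are assumptions chosen to make the effect visible.

```python
def sign(x):
    # Returns -1, 0 or 1, matching the mathematical sign function.
    return (x > 0) - (x < 0)

def adam_second_moment(v, g, b2):
    # Adam: exponential moving average of the squared gradient.
    return b2 * v + (1.0 - b2) * g * g

def yogi_second_moment(v, g, b2):
    # Yogi: additive update whose magnitude depends only on g^2,
    # so v cannot shrink quickly when gradients become small.
    g2 = g * g
    return v - (1.0 - b2) * sign(v - g2) * g2

b2 = 0.9              # illustrative value (assumption), not the default 0.999
adam_v = yogi_v = 1.0  # assume both estimates have settled at 1.0

# Gradients suddenly become small (|g| = 0.01) for 50 steps.
for _ in range(50):
    adam_v = adam_second_moment(adam_v, 0.01, b2)
    yogi_v = yogi_second_moment(yogi_v, 0.01, b2)

# The step size scales roughly with 1 / sqrt(v).
print('Adam scale:', adam_v ** -0.5)
print('Yogi scale:', yogi_v ** -0.5)
```

In this setting Adam's estimate decays geometrically toward g^2, so its step scale grows more than tenfold, while Yogi's estimate decreases by only (1 - b2) * g^2 per step and its step scale stays close to 1.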
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.yogi(learning_rate=0.002)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
References
Zaheer et al., Adaptive Methods for Nonconvex Optimization, 2018