optax.adadelta
- optax.adadelta(learning_rate: base.ScalarOrSchedule | None = None, rho: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-06, weight_decay: jax.typing.ArrayLike | base.ScalarOrSchedule = 0.0, weight_decay_mask: MaskOrFn = None) -> base.GradientTransformationExtraArgs
The Adadelta optimizer.
Adadelta is a stochastic gradient descent method that adapts per-parameter learning rates based on a moving window of gradient updates. It is a modification of Adagrad: Adagrad accumulates all past squared gradients, so its effective learning rate keeps shrinking, whereas Adadelta maintains exponentially decaying running averages of squared gradients and of squared updates.
The weight update \(\Delta w_t\) for this optimizer is given as follows:
\[\begin{align*} &E[g^2]_t = \rho \cdot E[g^2]_{t-1} + (1-\rho) \cdot g_t^2 \\ &\Delta w_t = -\frac{\sqrt{E[\Delta w^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}} \cdot g_t \end{align*}\]
where:
\(g_t\) is the gradient at time step \(t\),
\(E[g^2]_t\) is the running average of squared gradients,
\(E[\Delta w^2]_t\) is the running average of squared parameter updates,
\(\rho\) is the decay rate (typically 0.9),
\(\epsilon\) is a small constant for numerical stability.
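As a rough illustration of the two formulas above, the following sketch performs a single Adadelta step with plain `jax.numpy`. The function name `adadelta_step` and its arguments are hypothetical, and this is not optax's implementation: it omits `learning_rate` and `weight_decay`. The update of the \(E[\Delta w^2]\) accumulator is not written out in the formulas above; the sketch assumes the rule from Zeiler's paper, which mirrors the one for \(E[g^2]\).

```python
import jax.numpy as jnp


def adadelta_step(g, avg_sq_grad, avg_sq_update, rho=0.9, eps=1e-6):
    """One Adadelta step (hypothetical helper, not part of optax)."""
    # E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2
    avg_sq_grad = rho * avg_sq_grad + (1.0 - rho) * g ** 2
    # Delta w_t = -sqrt(E[dw^2]_{t-1} + eps) / sqrt(E[g^2]_t + eps) * g_t
    delta = -jnp.sqrt(avg_sq_update + eps) / jnp.sqrt(avg_sq_grad + eps) * g
    # Assumption (from Zeiler's paper, not the formulas above):
    # E[dw^2]_t = rho * E[dw^2]_{t-1} + (1 - rho) * Delta w_t^2
    avg_sq_update = rho * avg_sq_update + (1.0 - rho) * delta ** 2
    return delta, avg_sq_grad, avg_sq_update


w = jnp.array([1.0, 2.0, 3.0])
g = 2.0 * w  # gradient of sum(w ** 2)
zeros = jnp.zeros_like(w)  # both accumulators start at zero
delta, e_g2, e_dw2 = adadelta_step(g, zeros, zeros)
```

Because both accumulators start at zero, the first step is very small (on the order of \(\sqrt{\epsilon}\)) and has nearly the same magnitude for every coordinate, regardless of the gradient's scale.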
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
rho – A coefficient used for computing a running average of squared gradients.
eps – Term added to the denominator to improve numerical stability.
weight_decay – Optional rate at which to decay weights.
weight_decay_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
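To make the mask description concrete, here is a minimal sketch of building a boolean mask tree with `jax.tree_util.tree_map`. The parameter layout and the rule "decay only arrays with more than one dimension" are assumptions chosen for illustration, and `decay_mask_fn` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter PyTree for a single dense layer.
params = {
    "dense": {
        "kernel": jnp.ones((3, 2)),
        "bias": jnp.zeros((2,)),
    }
}


def decay_mask_fn(params):
    """Return a tree of booleans: True where weight decay should apply."""
    # Assumed rule: decay matrices (ndim > 1), skip biases and other vectors.
    return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)


mask = decay_mask_fn(params)
```

Either the resulting `mask` tree or the callable `decay_mask_fn` has the form that `weight_decay_mask` is described as accepting.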
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adadelta(learning_rate=10.)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.36E+01
Objective function: 1.32E+01
Objective function: 1.29E+01
Objective function: 1.25E+01
Objective function: 1.21E+01
References
Zeiler, ADADELTA: An Adaptive Learning Rate Method, 2012