optax.adadelta

optax.adadelta(learning_rate: base.ScalarOrSchedule | None = None, rho: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-06, weight_decay: jax.typing.ArrayLike | base.ScalarOrSchedule = 0.0, weight_decay_mask: MaskOrFn = None) -> base.GradientTransformationExtraArgs

The Adadelta optimizer.

Adadelta is a stochastic gradient descent method that adapts per-parameter learning rates based on a moving window of gradient updates. It is a modification of Adagrad: instead of accumulating all past squared gradients, which makes Adagrad's effective learning rate shrink over time, Adadelta maintains exponentially decaying running averages of squared gradients and of squared parameter updates.

The weight update \(\Delta w_t\) for this optimizer is given as follows:

\[\begin{align*} &E[g^2]_t = \rho \cdot E[g^2]_{t-1} + (1-\rho) \cdot g_t^2 \\ &\Delta w_t = -\frac{\sqrt{E[\Delta w^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}} \cdot g_t \\ &E[\Delta w^2]_t = \rho \cdot E[\Delta w^2]_{t-1} + (1-\rho) \cdot \Delta w_t^2 \end{align*}\]
where:
  • \(g_t\) is the gradient at time step \(t\),

  • \(E[g^2]_t\) is the running average of squared gradients,

  • \(E[\Delta w^2]_t\) is the running average of squared parameter updates,

  • \(\rho\) is the decay rate (typically 0.9),

  • \(\epsilon\) is a small constant for numerical stability.
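The update rule above can be sketched in plain Python for a single scalar parameter. This is a minimal illustration, not the optax implementation: the helper name adadelta_step and its arguments are hypothetical, the running average of squared updates is refreshed with the same decay rate as in Zeiler's paper, and the learning rate is assumed to rescale only the final update.

```python
import math


def adadelta_step(w, g, avg_sq_grad, avg_sq_update,
                  rho=0.9, eps=1e-6, learning_rate=1.0):
    """One Adadelta step for a scalar parameter (hypothetical helper)."""
    # E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2
    avg_sq_grad = rho * avg_sq_grad + (1.0 - rho) * g * g
    # Delta w_t = -sqrt(E[dw^2]_{t-1} + eps) / sqrt(E[g^2]_t + eps) * g_t
    delta = -math.sqrt(avg_sq_update + eps) / math.sqrt(avg_sq_grad + eps) * g
    # E[dw^2]_t = rho * E[dw^2]_{t-1} + (1 - rho) * Delta w_t^2
    avg_sq_update = rho * avg_sq_update + (1.0 - rho) * delta * delta
    # Assumption: the learning rate scales the update after the
    # running averages have been refreshed.
    w = w + learning_rate * delta
    return w, avg_sq_grad, avg_sq_update


def objective(ws):
    # Simple quadratic objective: sum of squares.
    return sum(w * w for w in ws)


params = [1.0, 2.0, 3.0]
avg_sq_grads = [0.0, 0.0, 0.0]
avg_sq_updates = [0.0, 0.0, 0.0]

history = []
for _ in range(5):
    for i, w in enumerate(params):
        g = 2.0 * w  # gradient of sum(w^2) with respect to w
        params[i], avg_sq_grads[i], avg_sq_updates[i] = adadelta_step(
            w, g, avg_sq_grads[i], avg_sq_updates[i], learning_rate=10.0
        )
    history.append(objective(params))
```

Under these assumptions the objective after the first step comes out near 13.62, which is consistent with the first value printed in the Examples section. Because the numerator starts at \(\sqrt{\epsilon}\), the first steps are small even with a large learning rate.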

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • rho – A coefficient used for computing a running average of squared gradients.

  • eps – Term added to the denominator to improve numerical stability.

  • weight_decay – Optional rate at which to decay weights.

  • weight_decay_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adadelta(learning_rate=10.)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.36E+01
Objective function: 1.32E+01
Objective function: 1.29E+01
Objective function: 1.25E+01
Objective function: 1.21E+01

References

Zeiler, ADADELTA: An Adaptive Learning Rate Method, 2012