optax.adagrad

optax.adagrad(learning_rate: base.ScalarOrSchedule, initial_accumulator_value: jax.typing.ArrayLike = 0.1, eps: jax.typing.ArrayLike = 1e-07) -> base.GradientTransformationExtraArgs

The Adagrad optimizer.

Adagrad is a sub-gradient algorithm for stochastic optimization that adapts the learning rate individually for each feature based on its gradient history.

The updated parameters adopt the form:

\[w_{t+1}^{(i)} = w_{t}^{(i)} - \eta \frac{g_{t}^{(i)}} {\sqrt{\sum_{\tau=1}^{t} (g_{\tau}^{(i)})^2 + \epsilon}}\]
where:
  • \(w_t^{(i)}\) is the parameter \(i\) at time step \(t\),

  • \(\eta\) is the learning rate,

  • \(g_t^{(i)}\) is the gradient of parameter \(i\) at time step \(t\),

  • \(\epsilon\) is a small constant to ensure numerical stability.

Defining \(G = \sum_{\tau=1}^{t} g_{\tau} g_{\tau}^\top\), the update can be written as

\[w_{t+1} = w_{t} - \eta \cdot \text{diag}(G + \epsilon I)^{-1/2} \cdot g_t\]

where \(\text{diag} (G) = (G_{ii})_{i=1}^p\) is the vector of diagonal entries of \(G \in \mathbb{R}^{p \times p}\) and \(I\) is the identity matrix in \(\mathbb{R}^{p \times p}\). The power \(-1/2\) and the product with \(g_t\) are taken elementwise.
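As a rough illustration of the per-coordinate rule, the sketch below applies one update by hand in plain Python. It assumes the running sum of squared gradients is seeded with the default `initial_accumulator_value` of 0.1 and uses the default `eps`; `adagrad_step` is a hypothetical helper written for this illustration, not part of optax.

```python
import math


def adagrad_step(w, g, acc, lr=1.0, eps=1e-7):
    """One per-coordinate Adagrad step on plain Python lists (illustrative only)."""
    # Accumulate squared gradients, one running sum per coordinate.
    new_acc = [a + gi * gi for a, gi in zip(acc, g)]
    # Scale each gradient by the inverse square root of its own accumulator;
    # eps sits inside the square root, as in the formula above.
    new_w = [wi - lr * gi / math.sqrt(a + eps) for wi, gi, a in zip(w, g, new_acc)]
    return new_w, new_acc


w = [1.0, 2.0, 3.0]
acc = [0.1] * len(w)        # assumption: accumulator seeded with 0.1
g = [2.0 * wi for wi in w]  # gradient of f(w) = sum(w ** 2)

w, acc = adagrad_step(w, g, acc)
objective = sum(wi * wi for wi in w)
print(f"Objective function: {objective:.2E}")
```

Under these assumptions the printed value is 5.01E+00, which agrees with the first objective value reported in the Examples section.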

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • initial_accumulator_value – Initial value for the accumulator.

  • eps – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adagrad(learning_rate=1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 5.01E+00
Objective function: 2.40E+00
Objective function: 1.25E+00
Objective function: 6.86E-01
Objective function: 3.85E-01

References

Duchi et al., Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, 2011

Warning

Adagrad's main limitation is the monotonic accumulation of squared gradients in the denominator: since every term is non-negative, the sum keeps growing during training and the effective learning rate eventually becomes vanishingly small.
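The shrinking step size can be seen in a small plain-Python sketch. It assumes a single parameter, a hypothetical constant gradient of 1.0, and an accumulator that starts at zero for simplicity, so the effective learning rate decays roughly like 1 / sqrt(t).

```python
import math

lr, eps = 1.0, 1e-7
acc = 0.0       # running sum of squared gradients (zero start is a simplification)
effective = []  # effective learning rate lr / sqrt(acc + eps) at each step

for t in range(1, 1001):
    g = 1.0     # hypothetical constant gradient
    acc += g * g
    effective.append(lr / math.sqrt(acc + eps))

# Effective learning rate after 1, 100, and 1000 steps.
print(f"{effective[0]:.3f} {effective[99]:.3f} {effective[999]:.3f}")
```

Even though the gradient never changes in this toy setting, the step taken at iteration 1000 is about thirty times smaller than the first one.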