optax.adagrad
- optax.adagrad(learning_rate: base.ScalarOrSchedule, initial_accumulator_value: jax.typing.ArrayLike = 0.1, eps: jax.typing.ArrayLike = 1e-07) -> base.GradientTransformationExtraArgs
The Adagrad optimizer.
AdaGrad is a sub-gradient algorithm for stochastic optimization that adapts the learning rate individually for each feature based on its gradient history.
The updated parameters adopt the form:
\[w_{t+1}^{(i)} = w_{t}^{(i)} - \eta \frac{g_{t}^{(i)}} {\sqrt{\sum_{\tau=1}^{t} (g_{\tau}^{(i)})^2 + \epsilon}}\]
where:
\(w_t^{(i)}\) is the parameter \(i\) at time step \(t\),
\(\eta\) is the learning rate,
\(g_t^{(i)}\) is the gradient of parameter \(i\) at time step \(t\),
\(\epsilon\) is a small constant to ensure numerical stability.
Defining \(G = \sum_{\tau=1}^{t} g_\tau g_\tau^\top\), the update can be written as
\[w_{t+1} = w_{t} - \eta \cdot \text{diag}(G + \epsilon I)^{-1/2} \cdot g_t\]
where \(\text{diag} (G) = (G_{ii})_{i=1}^p\) is the vector of diagonal entries of \(G \in \mathbb{R}^{p \times p}\), \(I\) is the identity matrix in \(\mathbb{R}^{p \times p}\), and both the power \(-1/2\) and the product with \(g_t\) are taken elementwise.
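The per-coordinate rule and the \(\text{diag}(G)\) form above can be checked against each other with a small dependency-free sketch. It follows the formulas as written and does not model `initial_accumulator_value`; the helper names `adagrad_step` and `outer_sum_diagonal` and the gradient values are hypothetical, chosen only for illustration.

```python
import math


def adagrad_step(w, g, sq_sum, lr=1.0, eps=1e-7):
    """Apply one per-coordinate update; returns (new_w, new_sq_sum)."""
    # Running sum of squared gradients, one entry per coordinate.
    new_sq_sum = [s + gi * gi for s, gi in zip(sq_sum, g)]
    # eps sits inside the square root, as in the formula above.
    new_w = [
        wi - lr * gi / math.sqrt(s + eps)
        for wi, gi, s in zip(w, g, new_sq_sum)
    ]
    return new_w, new_sq_sum


def outer_sum_diagonal(grads):
    """Diagonal of G = sum over steps of g g^T, via explicit outer products."""
    p = len(grads[0])
    G = [[0.0] * p for _ in range(p)]
    for g in grads:
        for i in range(p):
            for j in range(p):
                G[i][j] += g[i] * g[j]
    return [G[i][i] for i in range(p)]


# Hypothetical gradient sequence for a two-parameter problem.
grads = [[2.0, -4.0], [1.0, 0.5], [-3.0, 2.0]]
w = [1.0, 2.0]
sq_sum = [0.0, 0.0]
for g in grads:
    w, sq_sum = adagrad_step(w, g, sq_sum)

diag_G = outer_sum_diagonal(grads)
# Both routes give the same per-coordinate denominators.
print(sq_sum, diag_G)
```

Only the diagonal of \(G\) is ever used, so the running sum of squared gradients is all that needs to be stored; the full \(p \times p\) matrix is built here purely to show the equivalence.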
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
initial_accumulator_value – Initial value for the accumulator of squared gradients.
eps – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
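To make the roles of `initial_accumulator_value` and `eps` concrete, here is a pure-Python sketch of the update on the quadratic used in the Examples section. It assumes the accumulator of squared gradients starts at `initial_accumulator_value` and that `eps` is added inside the square root, as the parameter descriptions suggest. It illustrates those descriptions and is not the library's implementation; `adagrad_sketch`, `f` and `grad_f` are hypothetical names.

```python
import math


def f(x):
    """Simple quadratic objective: sum of squares."""
    return sum(xi ** 2 for xi in x)


def grad_f(x):
    """Analytic gradient of f."""
    return [2.0 * xi for xi in x]


def adagrad_sketch(params, steps, learning_rate=1.0,
                   initial_accumulator_value=0.1, eps=1e-7):
    """Run `steps` updates; returns (final_params, objective_history)."""
    # Assumption: the accumulator starts at initial_accumulator_value.
    accumulator = [initial_accumulator_value] * len(params)
    history = []
    for _ in range(steps):
        g = grad_f(params)
        accumulator = [a + gi * gi for a, gi in zip(accumulator, g)]
        # Assumption: eps is added inside the square root.
        params = [
            p - learning_rate * gi / math.sqrt(a + eps)
            for p, gi, a in zip(params, g, accumulator)
        ]
        history.append(f(params))
    return params, history


_, objective_history = adagrad_sketch([1.0, 2.0, 3.0], steps=5)
for value in objective_history:
    print("Objective function: {:.2E}".format(value))
```

Under these assumptions the first three objective values come out near 5.01, 2.40 and 1.25, in line with the values listed in the Examples section. A larger `initial_accumulator_value` damps the first few steps, since it inflates the denominator before many gradients have been accumulated.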
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adagrad(learning_rate=1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 5.01E+00
Objective function: 2.40E+00
Objective function: 1.25E+00
Objective function: 6.86E-01
Objective function: 3.85E-01
References
Duchi et al., Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, 2011
Warning
Adagrad's main limitation is the monotonic accumulation of squared gradients in the denominator: since every term is non-negative, the sum never decreases during training, so the effective learning rate keeps shrinking and can eventually become vanishingly small.
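The shrinking step size can be seen in a one-parameter sketch. It assumes a constant gradient of 1.0 at every step (a hypothetical worst case for accumulation) and the same accumulator conventions as the parameter descriptions; `effective_step_sizes` is an illustrative name, not part of optax.

```python
import math


def effective_step_sizes(grads, learning_rate=1.0,
                         initial_accumulator_value=0.1, eps=1e-7):
    """Return learning_rate / sqrt(accumulator + eps) after each gradient."""
    acc = initial_accumulator_value
    sizes = []
    for g in grads:
        acc += g * g  # squared gradients only ever add to the accumulator
        sizes.append(learning_rate / math.sqrt(acc + eps))
    return sizes


# Hypothetical run: the gradient stays at 1.0 for 1000 steps.
sizes = effective_step_sizes([1.0] * 1000)
for t in (1, 100, 1000):
    print("step {:4d}: effective step size {:.4f}".format(t, sizes[t - 1]))
```

With a gradient that does not decay, the multiplier on the gradient falls roughly like \(1/\sqrt{t}\), which is the behavior the warning describes.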