optax.polyak_sgd


optax.polyak_sgd(max_learning_rate: jax.typing.ArrayLike = 1.0, scaling: base.ScalarOrSchedule = 1.0, f_min: jax.typing.ArrayLike = 0.0, eps: jax.typing.ArrayLike = 0.0, variant: str = 'sps') → base.GradientTransformationExtraArgs

SGD with Polyak step-size.

This solver implements SGD with the Polyak step-size of Loizou et al. (2021). It sets the step-size as

\[s \min\left\{\frac{f(x) - f^\star}{\|\nabla f(x)\|^2 + \epsilon}, \gamma_{\max}\right\}\,, \]

where \(f\) is the function from which the gradient is computed, \(\gamma_{\max}\) is the maximal acceptable learning rate set by max_learning_rate, \(\epsilon\) is a constant set by eps that prevents division by zero, \(s\) is the scaling factor set by scaling, and \(f^\star\) is a guess of the minimum value of the function set by f_min.
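
As a rough illustration of the formula above, the step-size can be computed by hand for a simple quadratic. The helper polyak_step_size below is a hypothetical name introduced for this sketch and is not part of optax; it mirrors the formula term by term and, as an assumption of the sketch, does not guard against a zero gradient when eps is 0.

```python
import jax
import jax.numpy as jnp


def polyak_step_size(value, grad, f_min=0.0, eps=0.0,
                     max_learning_rate=1.0, scaling=1.0):
  """Hypothetical helper mirroring the step-size formula (not an optax API)."""
  grad_sq_norm = jnp.sum(grad ** 2)               # ||grad f(x)||^2
  ratio = (value - f_min) / (grad_sq_norm + eps)  # (f(x) - f*) / (||grad||^2 + eps)
  return scaling * jnp.minimum(ratio, max_learning_rate)


def f(x):
  return jnp.sum(x ** 2)  # simple quadratic, minimum value 0


params = jnp.array([1., 2., 3.])
value, grad = jax.value_and_grad(f)(params)  # value = 14, grad = [2, 4, 6]
step = polyak_step_size(value, grad)         # 14 / 56 = 0.25
print(step)
```

With this step, one SGD update maps params to params - 0.25 * grad, which halves every coordinate and divides the objective by 4.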

Setting variant="sps+" (Garrigos et al. 2023) uses only the non-negative part of the suboptimality gap. That is, it replaces \(f(x) - f^\star\) with \((f(x) - f^\star)_+\), where \(a_+ = \max \{a, 0\}\).
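
The difference between the two variants only shows up when the guess \(f^\star\) exceeds the current objective value. The sketch below applies the two formulas by hand, without calling optax; all numeric values are made up for illustration.

```python
import jax.numpy as jnp

value = 0.5              # current objective value f(x) (illustrative)
f_min = 1.0              # a guess that overestimates the true minimum
grad_sq_norm = 2.0       # squared gradient norm (illustrative)
max_learning_rate = 1.0

gap = value - f_min      # negative suboptimality gap: -0.5

# 'sps': the raw gap is used, so the step-size can become negative.
step_sps = jnp.minimum(gap / grad_sq_norm, max_learning_rate)

# 'sps+': only the non-negative part of the gap is used.
step_sps_plus = jnp.minimum(jnp.maximum(gap, 0.0) / grad_sq_norm,
                            max_learning_rate)

print(step_sps, step_sps_plus)
```

Here the 'sps' formula gives a step-size of -0.25, which would move along the gradient rather than against it, while the 'sps+' formula gives a step-size of 0 and leaves the parameters unchanged.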

Parameters:
  • max_learning_rate – a maximum step size to use (defaults to 1).

  • scaling – a global scaling factor, either fixed or evolving along iterations with a scheduler (defaults to 1).

  • f_min – a lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the formula above.

  • eps – a value to add in the denominator of the update (defaults to 0).

  • variant – either 'sps' or 'sps+' (defaults to 'sps').

Returns:

An optax.GradientTransformationExtraArgs, where the update function takes an additional keyword argument value containing the current value of the objective function.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.polyak_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.875
Objective function:  0.21875
Objective function:  0.0546875
Objective function:  0.013671875

References

Loizou et al., Stochastic Polyak step-size for SGD: An adaptive learning rate for fast convergence, 2021

Berrada et al., Training neural networks for and by interpolation, 2020

Garrigos et al., Function value learning: Adaptive learning rates based on the Polyak stepsize and function splitting in ERM, 2023

Warning

This method requires knowledge of an approximate value of the objective function minimum, passed through the f_min argument. For models that interpolate the data, this can be set to 0 (default value). Failing to set an appropriate value for f_min can lead to divergence or convergence to a suboptimal solution.