optax.polyak_sgd
optax.polyak_sgd(max_learning_rate: jax.typing.ArrayLike = 1.0, scaling: base.ScalarOrSchedule = 1.0, f_min: jax.typing.ArrayLike = 0.0, eps: jax.typing.ArrayLike = 0.0, variant: str = 'sps') → base.GradientTransformationExtraArgs
SGD with Polyak step-size.
This solver implements SGD with the Polyak step size of Loizou et al. (2021). It sets the step size as
\[s \min\left\{\frac{f(x) - f^\star}{\|\nabla f(x)\|^2 + \epsilon}, \gamma_{\max}\right\}\,,\]
where \(f\) is the function from which a gradient is computed, \(\gamma_{\max}\) is a maximal acceptable learning rate set by max_learning_rate, \(\epsilon\) is a constant preventing division by zero set with eps, \(s\) is a global scaling factor set with scaling, and \(f^\star\) is a guess of the minimum value of the function set with f_min.
Setting variant="sps+" (Garrigos et al. 2023) uses only the non-negative part of the suboptimality gap. That is, it replaces \(f(x) - f^\star\) with \((f(x) - f^\star)_+\), where \(a_+ = \max \{a, 0\}\).
- Parameters:
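The step-size formula can be sketched in a few lines of jax.numpy to make the role of each symbol concrete. polyak_step_size is a hypothetical helper written for illustration only, not part of the Optax API: it returns the scalar step size (not the parameter update), takes scaling as a constant, and omits any safeguard for a zero denominator, so a positive eps is the only protection against division by zero here.

```python
import jax.numpy as jnp


def polyak_step_size(value, grad, f_min=0.0, eps=0.0,
                     max_learning_rate=1.0, scaling=1.0, variant="sps"):
    """Hypothetical helper: s * min{gap / (||grad||^2 + eps), gamma_max}."""
    gap = value - f_min  # suboptimality gap f(x) - f*
    if variant == "sps+":
        # "sps+" keeps only the non-negative part of the gap.
        gap = jnp.maximum(gap, 0.0)
    grad_sq_norm = jnp.sum(grad ** 2)  # ||grad f(x)||^2 for a single array
    return scaling * jnp.minimum(gap / (grad_sq_norm + eps), max_learning_rate)


x = jnp.array([1.0, 2.0, 3.0])
value, grad = jnp.sum(x ** 2), 2 * x  # f(x) = ||x||^2 and its gradient

print(polyak_step_size(value, grad))  # 14 / 56 = 0.25
# An f_min above f(x) gives a negative gap: "sps" yields a negative step,
# while "sps+" clips the step to 0.
print(polyak_step_size(value, grad, f_min=20.0))
print(polyak_step_size(value, grad, f_min=20.0, variant="sps+"))
```

With f_min=0 on this quadratic the step is 14/56 = 0.25, which is why each iteration of the example further down halves the parameters.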
max_learning_rate – a maximum step size to use (defaults to 1). Corresponds to \(\gamma_{\max}\) in the formula above.
scaling – a global scaling factor, either a constant or a schedule that evolves over iterations (defaults to 1). Corresponds to \(s\) in the formula above.
f_min – a lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the formula above.
eps – a value added to the denominator of the step size to prevent division by zero (defaults to 0). Corresponds to \(\epsilon\) in the formula above.
variant – either 'sps' or 'sps+' (defaults to 'sps').
- Returns:
An optax.GradientTransformationExtraArgs whose update function takes an additional keyword argument value containing the current value of the objective function.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.polyak_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.875
Objective function:  0.21875
Objective function:  0.0546875
Objective function:  0.013671875
References
Loizou et al., Stochastic Polyak step-size for SGD: An adaptive learning rate for fast convergence, 2021
Berrada et al., Training neural networks for and by interpolation, 2020
Garrigos et al., Function value learning: Adaptive learning rates based on the Polyak stepsize and function splitting in ERM, 2023
Warning
This method requires knowledge of an approximate value of the objective function minimum, passed through the f_min argument. For models that interpolate the data, this can be set to 0 (the default value). Failing to set an appropriate value for f_min can lead to divergence or convergence to a suboptimal solution.