optax.scale_by_zoom_linesearch

optax.scale_by_zoom_linesearch(max_linesearch_steps: jax.typing.ArrayLike, max_learning_rate: jax.typing.ArrayLike | None = None, tol: jax.typing.ArrayLike = 0.0, increase_factor: jax.typing.ArrayLike = 2.0, slope_rtol: jax.typing.ArrayLike = 0.0001, curv_rtol: jax.typing.ArrayLike = 0.9, approx_dec_rtol: jax.typing.ArrayLike | None = 1e-06, stepsize_precision: jax.typing.ArrayLike = 1e-05, initial_guess_strategy: str = 'keep', verbose: bool = False) → base.GradientTransformationExtraArgs

Linesearch ensuring sufficient decrease and small curvature.

This algorithm searches for a learning rate, a.k.a. stepsize, that satisfies both a sufficient decrease criterion, a.k.a. Armijo-Goldstein criterion,

\[f(w + \eta u) \leq f(w) + \eta c_1 \langle u, \nabla f(w) \rangle + \epsilon \,, \]

and a small curvature (along the update direction) criterion, a.k.a. Wolfe or second Wolfe criterion,

\[|\langle \nabla f(w + \eta u), u \rangle| \leq c_2 |\langle \nabla f(w), u \rangle| + \epsilon\,, \]

where

  • \(f\) is the function to minimize,

  • \(w\) are the current parameters,

  • \(\eta\) is the learning rate to find,

  • \(u\) is the update direction,

  • \(c_1\) is a coefficient (slope_rtol) measuring the relative decrease of the function in terms of the slope (scalar product between the gradient and the updates),

  • \(c_2\) is a coefficient (curv_rtol) measuring the relative decrease of curvature.

  • \(\epsilon\) is an absolute tolerance (tol).
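
As a rough illustration of the two criteria above, the following plain-Python sketch evaluates them for a few candidate learning rates. The helper names (dot, check_criteria) and the quadratic objective are assumptions made for this example; they are not part of optax.

```python
def dot(a, b):
    """Scalar product of two equally sized lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def check_criteria(f, grad_f, w, u, eta, c1=1e-4, c2=0.9, eps=0.0):
    """Returns (sufficient_decrease, small_curvature) for the stepsize eta."""
    w_next = [wi + eta * ui for wi, ui in zip(w, u)]
    slope_init = dot(grad_f(w), u)       # <grad f(w), u>
    slope_step = dot(grad_f(w_next), u)  # <grad f(w + eta u), u>
    sufficient_decrease = f(w_next) <= f(w) + eta * c1 * slope_init + eps
    small_curvature = abs(slope_step) <= c2 * abs(slope_init) + eps
    return sufficient_decrease, small_curvature


# Hypothetical objective f(w) = sum(w_i ** 2); u is the negative gradient.
f = lambda w: sum(wi ** 2 for wi in w)
grad_f = lambda w: [2.0 * wi for wi in w]
w = [1.0, 2.0, 3.0]
u = [-g for g in grad_f(w)]

too_small = check_criteria(f, grad_f, w, u, eta=0.01)  # (True, False)
exact_min = check_criteria(f, grad_f, w, u, eta=0.5)   # (True, True)
too_large = check_criteria(f, grad_f, w, u, eta=1.5)   # (False, False)
```

A very small stepsize decreases the objective but leaves the slope almost unchanged, so it fails the small curvature criterion; a very large one fails both.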

To deal with very flat functions, this linesearch switches from the sufficient decrease criterion presented above to an approximate sufficient decrease criterion introduced by Hager and Zhang (see [Hager and Zhang, 2006]):

\[|\langle \nabla f(w+\eta u), u \rangle| \leq (2 c_1 - 1) |\langle \nabla f(w), u \rangle| + \epsilon\,. \]

The approximate sufficient decrease criterion is used only if the values tried by the linesearch fall below a relative decrease of the initial function, that is,

\[f(w + \eta u) \leq f(w) + c_3 |f(w)| \]

where \(c_3\) is a coefficient (approx_dec_rtol) measuring the relative decrease of the objective (see reference below and comments in the code for more details).

The original sufficient decrease criterion can only capture differences up to \(\sqrt{\varepsilon_{machine}}\) while the approximate sufficient decrease criterion can capture differences up to \(\varepsilon_{machine}\) (see [Hager and Zhang, 2006]). Note that this add-on is not part of the original implementation of [Nocedal and Wright, 1999] and can be removed by setting approx_dec_rtol to None.
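
The precision argument can be illustrated with a small floating-point experiment. The function f(w) = 1 + w^2 used below is a hypothetical example chosen for this sketch: close to its minimizer, differences in function values fall below double-precision resolution, while the slope is still resolved.

```python
import sys

eps_machine = sys.float_info.epsilon  # about 2.2e-16 in double precision


def f(w):
    return 1.0 + w ** 2


def slope(w):
    return 2.0 * w


w = 1e-9  # distance to the minimizer, below sqrt(eps_machine) ~ 1.5e-8

# The function value is indistinguishable from the minimum f(0) = 1 ...
values_equal = f(w) == f(0.0)
# ... while the slope still tells the two points apart.
slopes_differ = slope(w) != slope(0.0)
```

Both values_equal and slopes_differ are True here, which is why a criterion based on slopes can keep making progress where one based on function values stalls.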

Parameters:
  • max_linesearch_steps – maximum number of linesearch iterations.

  • max_learning_rate – maximum admissible learning rate. Can be set to None for no upper bound. A value other than None may prevent the linesearch from finding a learning rate satisfying the small curvature criterion, since the latter may require sufficiently large stepsizes.

  • tol – tolerance on the criteria.

  • increase_factor – factor by which the learning rate is increased when searching for a valid interval containing a learning rate satisfying both criteria.

  • slope_rtol – relative tolerance for the slope in the sufficient decrease criterion.

  • curv_rtol – relative tolerance for the curvature in the small curvature criterion.

  • approx_dec_rtol – relative tolerance for the initial value in the approximate sufficient decrease criterion. Can be set to None to use only the original Armijo-Goldstein decrease criterion.

  • stepsize_precision – precision in the search of a stepsize satisfying both conditions. The algorithm proceeds with a bisection that refines an interval containing a stepsize satisfying both conditions. If that interval is reduced below stepsize_precision and a stepsize satisfying a sufficient decrease has been found, the algorithm selects that stepsize even if the curvature condition is not satisfied.

  • initial_guess_strategy – initial guess for the learning rate used to start the linesearch. Can be either 'one' or 'keep'. If 'one', the initial guess is set to 1. If 'keep', the initial guess is set to the learning rate of the previous step. We recommend using 'keep' if this linesearch is used in combination with SGD, and 'one' if it is used in combination with Newton methods or quasi-Newton methods such as L-BFGS.

  • verbose – whether to print additional debugging information in case the linesearch fails.
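
To make the roles of increase_factor, max_learning_rate and stepsize_precision more concrete, here is a simplified plain-Python sketch of a search that first grows the stepsize and then bisects an interval. It is not the optax implementation: the function simple_linesearch, its arguments phi, dphi and init_stepsize, and the use of plain bisection are assumptions made for illustration.

```python
def simple_linesearch(phi, dphi, init_stepsize=1.0, max_linesearch_steps=30,
                      max_learning_rate=None, increase_factor=2.0,
                      slope_rtol=1e-4, curv_rtol=0.9, stepsize_precision=1e-5):
    """Simplified search on phi(eta) = f(w + eta * u); dphi is its derivative."""
    value_init, slope_init = phi(0.0), dphi(0.0)
    low, high = 0.0, None  # search interval; high is unknown at first
    best = 0.0             # last stepsize satisfying sufficient decrease
    eta = init_stepsize
    if max_learning_rate is not None:
        eta = min(eta, max_learning_rate)
    for _ in range(max_linesearch_steps):
        slope_step = dphi(eta)
        decrease_ok = phi(eta) <= value_init + slope_rtol * eta * slope_init
        curvature_ok = abs(slope_step) <= curv_rtol * abs(slope_init)
        if decrease_ok:
            best = eta
            if curvature_ok:
                return eta  # both criteria hold
        if not decrease_ok or slope_step > 0:
            high = eta  # stepsize too large: shrink from above
        else:
            low = eta   # slope still steeply negative: stepsize too small
        if high is None:
            if eta == max_learning_rate:
                return best  # capped: curvature criterion out of reach
            eta = eta * increase_factor
            if max_learning_rate is not None:
                eta = min(eta, max_learning_rate)
        else:
            if high - low < stepsize_precision:
                break  # interval too small: settle for sufficient decrease
            eta = 0.5 * (low + high)
    return best  # 0.0 if no tried stepsize gave a sufficient decrease


# Hypothetical one-dimensional objective along the update direction.
phi = lambda eta: (eta - 3.0) ** 2
dphi = lambda eta: 2.0 * (eta - 3.0)

grown = simple_linesearch(phi, dphi, init_stepsize=0.1)    # grows 0.1 -> 0.4
shrunk = simple_linesearch(phi, dphi, init_stepsize=10.0)  # bisects 10 -> 5
capped = simple_linesearch(phi, dphi, init_stepsize=0.1,
                           max_learning_rate=0.15)         # stops at the cap
```

In the capped run, the returned stepsize satisfies the sufficient decrease criterion but not the small curvature criterion, which is the situation the max_learning_rate description warns about.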

Returns:

An optax.GradientTransformationExtraArgs object consisting of an init and an update function.

Examples

An example of using the zoom linesearch with SGD:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> solver = optax.chain(
...    optax.sgd(learning_rate=1.),
...    optax.scale_by_zoom_linesearch(max_linesearch_steps=15)
... )
>>> # Function with additional inputs other than params
>>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y)
>>> params = jnp.array([1., 2., 3.])
>>> opt_state = solver.init(params)
>>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.)
>>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,))
>>> print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 5.00E+01
>>> for x, y in zip(xs, ys):
...   value, grad = jax.value_and_grad(fn)(params, x, y)
...   updates, opt_state = solver.update(
...       grad,
...       opt_state,
...       params,
...       value=value,
...       grad=grad,
...       value_fn=fn,
...       x=x,
...       y=y
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 2.56E-13
Objective function: 2.84E-14
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00

A similar example, but with a non-stochastic function where we can reuse the value and the gradient computed at the end of the linesearch:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> # Function without extra arguments
>>> def fn(params): return jnp.sum(params ** 2)
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.chain(
...    optax.sgd(learning_rate=1.),
...    optax.scale_by_zoom_linesearch(max_linesearch_steps=15)
... )
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00

References

Algorithms 3.5 and 3.6 of Nocedal and Wright, Numerical Optimization, 1999

Hager and Zhang, Algorithm 851: CG_DESCENT, a Conjugate Gradient Method with Guaranteed Descent, 2006

Note

The curvature criterion can be disabled by setting curv_rtol=jnp.inf. The resulting algorithm amounts to a backtracking linesearch in which a point satisfying sufficient decrease is sought by minimizing a quadratic or cubic approximation of the objective. This can be sufficient in practice and avoids having the linesearch spend many iterations trying to satisfy the small curvature criterion.
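
The idea of backtracking by minimizing a quadratic approximation can be sketched in plain Python as follows. This is an illustrative textbook-style variant, not the optax code: the function backtracking_quadratic and its arguments are hypothetical, and practical implementations additionally safeguard the interpolated stepsize.

```python
def backtracking_quadratic(phi, dphi0, init_stepsize=1.0, slope_rtol=1e-4,
                           max_linesearch_steps=15):
    """Backtracks on phi(eta) = f(w + eta * u), assuming dphi0 = phi'(0) < 0."""
    value_init = phi(0.0)
    eta = init_stepsize
    for _ in range(max_linesearch_steps):
        value_step = phi(eta)
        if value_step <= value_init + slope_rtol * eta * dphi0:
            return eta  # sufficient decrease reached
        # Minimizer of the quadratic q with q(0) = phi(0), q'(0) = phi'(0)
        # and q(eta) = phi(eta); it is always smaller than the current eta.
        eta = -dphi0 * eta ** 2 / (2.0 * (value_step - value_init - dphi0 * eta))
    return eta


# Hypothetical objective along the update direction, minimized at eta = 3.
phi = lambda eta: (eta - 3.0) ** 2
stepsize = backtracking_quadratic(phi, dphi0=-6.0, init_stepsize=10.0)
```

Since phi is itself quadratic in this example, a single interpolation step lands on its minimizer, eta = 3.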

Note

The algorithm can support complex inputs.

See also

optax.value_and_grad_from_state() to make this method more efficient for non-stochastic objectives.