optax.scale_by_backtracking_linesearch


optax.scale_by_backtracking_linesearch(max_backtracking_steps: jax.typing.ArrayLike, slope_rtol: jax.typing.ArrayLike = 0.0001, decrease_factor: jax.typing.ArrayLike = 0.8, increase_factor: jax.typing.ArrayLike = 1.5, max_learning_rate: jax.typing.ArrayLike = 1.0, atol: jax.typing.ArrayLike = 0.0, rtol: jax.typing.ArrayLike = 0.0, store_grad: bool = False, verbose: bool = False) → base.GradientTransformationExtraArgs

Backtracking line-search ensuring sufficient decrease (Armijo criterion).

Selects a learning rate \(\eta\) that satisfies the sufficient decrease criterion

\[f(w + \eta u) \leq (1+\delta)f(w) + \eta c \langle u, \nabla f(w) \rangle + \epsilon \,, \]

where

  • \(f\) is the function to minimize,

  • \(w\) are the current parameters,

  • \(\eta\) is the learning rate to find,

  • \(u\) is the update direction,

  • \(c\) is a coefficient (slope_rtol) measuring the relative decrease of the function in terms of the slope (the scalar product between the gradient and the updates),

  • \(\delta\) is a relative tolerance (rtol),

  • \(\epsilon\) is an absolute tolerance (atol).
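To make the criterion concrete, here is a minimal plain-Python sketch that evaluates both sides of the inequality for parameters stored as a list of floats. The helpers `dot` and `sufficient_decrease` are hypothetical illustrations, not optax functions; their defaults mirror slope_rtol, rtol and atol. On \(f(w)=\sum_i w_i^2\) with \(u=-\nabla f(w)\) and \(w = (1, 2, 3)\), the step \(\eta=1\) is rejected while \(\eta=0.8\) is accepted.

```python
def dot(a, b):
    """Euclidean inner product of two equal-length lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def sufficient_decrease(f, w, u, grad_w, eta, c=1e-4, rtol=0.0, atol=0.0):
    """Check f(w + eta*u) <= (1 + rtol) f(w) + eta*c*<u, grad f(w)> + atol."""
    candidate = [wi + eta * ui for wi, ui in zip(w, u)]
    lhs = f(candidate)
    rhs = (1.0 + rtol) * f(w) + eta * c * dot(u, grad_w) + atol
    return lhs <= rhs


def f(w):
    """Quadratic objective f(w) = sum_i w_i^2, whose gradient is 2w."""
    return sum(wi * wi for wi in w)


w = [1.0, 2.0, 3.0]
grad_w = [2.0 * wi for wi in w]
u = [-gi for gi in grad_w]  # plain gradient step: u = -grad f(w)

print(sufficient_decrease(f, w, u, grad_w, eta=1.0))  # False
print(sufficient_decrease(f, w, u, grad_w, eta=0.8))  # True
```

With \(\eta = 1\) the step lands on \(-w\), so the value stays at 14 and cannot beat the required decrease; a positive atol loosens the test enough to accept it.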

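The algorithm starts from an initial guess for the learning rate (the previously accepted learning rate multiplied by increase_factor and clipped to max_learning_rate) and multiplies it by decrease_factor until the criterion above is met or max_backtracking_steps is reached.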
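The backtracking loop can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not optax's implementation: parameters are lists of floats, the function name `backtracking_linesearch` and its `prev_lr` argument are hypothetical, and what the sketch returns when the step budget runs out is an assumption. The keyword arguments mirror the optax parameter names and defaults, and the initial guess follows the increase_factor and max_learning_rate descriptions below.

```python
def backtracking_linesearch(
    f, w, u, grad_w, prev_lr,
    max_backtracking_steps=15, slope_rtol=1e-4,
    decrease_factor=0.8, increase_factor=1.5, max_learning_rate=1.0,
    atol=0.0, rtol=0.0,
):
    """Hypothetical sketch of a backtracking (Armijo) line search.

    Parameters are lists of floats; prev_lr is the learning rate accepted
    at the previous iteration.
    """
    value = f(w)
    slope = sum(ui * gi for ui, gi in zip(u, grad_w))  # <u, grad f(w)>
    # Initial guess: grow the previous learning rate, then clip it.
    lr = min(increase_factor * prev_lr, max_learning_rate)
    for _ in range(max_backtracking_steps):
        candidate = [wi + lr * ui for wi, ui in zip(w, u)]
        bound = (1.0 + rtol) * value + lr * slope_rtol * slope + atol
        if f(candidate) <= bound:
            return lr  # sufficient decrease holds: accept lr
        lr *= decrease_factor  # criterion violated: shrink the step
    # Budget exhausted: return the last shrunk value (a sketch assumption).
    return lr


def f(w):
    return sum(wi * wi for wi in w)  # f(w) = sum_i w_i^2, gradient 2w


w = [1.0, 2.0, 3.0]
grad_w = [2.0 * wi for wi in w]
u = [-g for g in grad_w]  # plain gradient step

print(backtracking_linesearch(f, w, u, grad_w, prev_lr=1.0))  # 0.8
```

Starting from a previous learning rate of 1.0, the guess is clipped to 1.0, rejected once, and 0.8 is accepted, which matches the first step of the second example below (value 14 reduced to 5.04).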

Parameters:
  • max_backtracking_steps – maximum number of iterations for the line-search.

  • slope_rtol – relative tolerance with respect to the slope. The sufficient decrease must be slope_rtol * lr * <grad, updates>; see the formula above.

  • decrease_factor – factor by which the learning rate is multiplied at each backtracking step.

  • increase_factor – factor by which the learning rate guess is increased at each new iteration. Setting it to 1. amounts to keeping the current guess; setting it to math.inf amounts to starting from max_learning_rate at each round.

  • max_learning_rate – maximum learning rate (the learning rate guess is clipped to this value).

  • atol – absolute tolerance at which the criterion needs to be satisfied.

  • rtol – relative tolerance at which the criterion needs to be satisfied.

  • store_grad – whether to compute and store the gradient at the end of the line search. Since the function is evaluated to decide whether to accept the learning rate, its gradient can be computed along the way. The value and gradient computed at the end of the line search can then be reused directly at the next iteration via optax.value_and_grad_from_state(). See the second example below.

  • verbose – whether to print debugging information.
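How increase_factor and max_learning_rate shape the first guess at each iteration can be illustrated with a small plain-Python sketch. The helpers `initial_guess` and `guess_sequence` are hypothetical, and the simulation assumes every guess is accepted without backtracking, which isolates the effect of the two parameters.

```python
import math


def initial_guess(prev_lr, increase_factor, max_learning_rate):
    """First learning rate tried at a new iteration (hypothetical helper)."""
    return min(increase_factor * prev_lr, max_learning_rate)


def guess_sequence(start_lr, increase_factor, max_learning_rate=1.0, steps=5):
    """Guesses over several iterations, assuming each guess is accepted."""
    lrs, lr = [], start_lr
    for _ in range(steps):
        lr = initial_guess(lr, increase_factor, max_learning_rate)
        lrs.append(lr)
    return lrs


for factor in (1.0, 1.5, math.inf):
    print(factor, [round(lr, 4) for lr in guess_sequence(0.2, factor)])
```

Starting from 0.2, increase_factor=1.0 keeps the guess at 0.2 forever, 1.5 grows it geometrically until it is clipped at max_learning_rate, and math.inf restarts at max_learning_rate every round.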

Returns:

A GradientTransformationExtraArgs, where the update function takes the following additional keyword arguments:

  • value: value of the function at the current params.

  • grad: gradient of the function at the current params.

  • value_fn: function returning the value of the function we seek to optimize.

  • **extra_args: additional keyword arguments forwarded to value_fn. If the function needs arguments beyond the parameters, such as input data, pass them here (see the first example below).
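The calling convention for value_fn and **extra_args can be sketched in plain Python, assuming the line search evaluates value_fn at params + lr * updates and forwards the extra keyword arguments unchanged. The helper `candidate_value` is hypothetical and not part of optax; `fn` reproduces the squared-error loss of the first example below (0.5 * (<x, params> - y)^2) without JAX.

```python
def candidate_value(value_fn, params, updates, lr, **extra_args):
    """Evaluate value_fn at params + lr * updates, forwarding extra_args.

    Hypothetical helper illustrating the calling convention only.
    """
    candidate = [p + lr * u for p, u in zip(params, updates)]
    return value_fn(candidate, **extra_args)


def fn(params, x, y):
    """Squared-error loss 0.5 * (<x, params> - y)^2 with extra inputs x, y."""
    pred = sum(xi * pi for xi, pi in zip(x, params))
    return 0.5 * (pred - y) ** 2


params = [1.0, 2.0, 3.0]
x, y = [3.0, 2.0, 1.0], 0.0
zero_step = [0.0, 0.0, 0.0]
print(candidate_value(fn, params, zero_step, lr=1.0, x=x, y=y))  # 50.0
```

A zero step recovers the initial objective value of the first example (5.00E+01); the data x and y travel through the keyword arguments, just as x=x, y=y are passed to solver.update there.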

Examples

An example of using the backtracking line search with SGD:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> solver = optax.chain(
...    optax.sgd(learning_rate=1.),
...    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15)
... )
>>> # Function with additional inputs other than params
>>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y)
>>> params = jnp.array([1., 2., 3.])
>>> opt_state = solver.init(params)
>>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.)
>>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,))
>>> print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 5.00E+01
>>> for x, y in zip(xs, ys):
...   value, grad = jax.value_and_grad(fn)(params, x, y)
...   updates, opt_state = solver.update(
...       grad,
...       opt_state,
...       params,
...       value=value,
...       grad=grad,
...       value_fn=fn,
...       x=x,
...       y=y
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 3.86E+01
Objective function: 2.50E+01
Objective function: 1.34E+01
Objective function: 5.87E+00
Objective function: 5.81E+00

A similar example, but with a non-stochastic function where we can reuse the value and the gradient computed at the end of the linesearch:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> # Function without extra arguments
>>> def fn(params): return jnp.sum(params ** 2)
>>> params = jnp.array([1., 2., 3.])
>>> # In this case we can store the value and grad with the store_grad field
>>> # and reuse them using optax.value_and_grad_from_state
>>> solver = optax.chain(
...    optax.sgd(learning_rate=1.),
...    optax.scale_by_backtracking_linesearch(
...        max_backtracking_steps=15, store_grad=True
...    )
... )
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02

References

Vaswani et al., Painless Stochastic Gradient, 2019

Nocedal & Wright, Numerical Optimization, 1999

Warning

The sufficient decrease criterion might be impossible to satisfy for some update directions. To guarantee a non-trivial solution to the sufficient decrease criterion, the updates \(u\) must form a descent direction. An update \(u\) is a descent direction if the derivative of \(f(w + \eta u)\) at \(\eta = 0\), i.e., \(\langle u, \nabla f(w)\rangle\), is negative. This condition is automatically satisfied when using optax.sgd() (without momentum), but may not hold for other optimizers such as optax.adam().

More generally, when the line search is chained with other transforms as optax.chain(opt_1, ..., opt_k, scale_by_backtracking_linesearch(max_backtracking_steps=...), opt_kplusone, ..., opt_n), the updates returned by chaining opt_1, ..., opt_k must be a descent direction. Transforms placed after the backtracking line search, however, do not need to satisfy the descent direction property (one could, for example, use momentum).
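The descent-direction condition from this warning can be checked numerically with a plain-Python sketch (hypothetical helper `slope`, quadratic objective \(f(w)=\sum_i w_i^2\), rtol = atol = 0, \(c = 10^{-4}\)). An SGD-style update \(u = -\text{lr}\,\nabla f(w)\) has a negative slope, while \(u = +\nabla f(w)\) has a positive one, and none of the tried step sizes then passes the criterion because \(f\) only increases along \(u\).

```python
def slope(u, grad_w):
    """<u, grad f(w)>: derivative of eta -> f(w + eta*u) at eta = 0."""
    return sum(ui * gi for ui, gi in zip(u, grad_w))


def f(w):
    return sum(wi * wi for wi in w)  # f(w) = sum_i w_i^2, gradient 2w


w = [1.0, 2.0, 3.0]
grad_w = [2.0 * wi for wi in w]
sgd_dir = [-0.1 * g for g in grad_w]  # SGD without momentum: u = -lr * grad
ascent_dir = list(grad_w)             # u = +grad: not a descent direction

print(slope(sgd_dir, grad_w) < 0)     # True
print(slope(ascent_dir, grad_w) < 0)  # False

# Step sizes that satisfy the criterion along the ascent direction.
c = 1e-4
accepted = [
    eta for eta in (1.0, 0.1, 0.01, 0.001)
    if f([wi + eta * ui for wi, ui in zip(w, ascent_dir)])
    <= f(w) + eta * c * slope(ascent_dir, grad_w)
]
print(accepted)  # []
```

Shrinking \(\eta\) further does not help here: along an ascent direction the left-hand side exceeds \(f(w)\) for every \(\eta > 0\), while the right-hand side stays above \(f(w)\) only by the tiny term \(\eta c \langle u, \nabla f(w)\rangle\).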

Note

The algorithm supports complex inputs.

See also

optax.value_and_grad_from_state() to make this method more efficient for non-stochastic objectives.

Added in version 0.2.0.