optax.scale_by_backtracking_linesearch
- optax.scale_by_backtracking_linesearch(max_backtracking_steps: jax.typing.ArrayLike, slope_rtol: jax.typing.ArrayLike = 0.0001, decrease_factor: jax.typing.ArrayLike = 0.8, increase_factor: jax.typing.ArrayLike = 1.5, max_learning_rate: jax.typing.ArrayLike = 1.0, atol: jax.typing.ArrayLike = 0.0, rtol: jax.typing.ArrayLike = 0.0, store_grad: bool = False, verbose: bool = False) -> base.GradientTransformationExtraArgs
Backtracking line-search ensuring sufficient decrease (Armijo criterion).
Selects learning rate \(\eta\) such that it verifies the sufficient decrease criterion
\[f(w + \eta u) \leq (1+\delta)f(w) + \eta c \langle u, \nabla f(w) \rangle + \epsilon \,, \]
where \(f\) is the function to minimize, \(w\) are the current parameters, \(\eta\) is the learning rate to find, \(u\) is the update direction, \(c\) is a coefficient (slope_rtol) measuring the relative decrease of the function in terms of the slope (scalar product between the gradient and the updates), \(\delta\) is a relative tolerance (rtol), and \(\epsilon\) is an absolute tolerance (atol).
The algorithm starts with a given guess of the learning rate and decreases it by decrease_factor until the criterion above is met.
- Parameters:
max_backtracking_steps – maximum number of iterations for the line-search.
slope_rtol – relative tolerance w.r.t. the slope. The sufficient decrease must be slope_rtol * lr * <grad, updates>, see formula above.
decrease_factor – factor by which the learning rate is reduced at each backtracking step.
increase_factor – factor by which the learning rate guess is increased between rounds. Setting it to 1. amounts to keeping the current guess; setting it to math.inf amounts to starting with max_learning_rate at each round.
max_learning_rate – maximum learning rate (the learning rate guess is clipped to this).
atol – absolute tolerance at which the criterion needs to be satisfied.
rtol – relative tolerance at which the criterion needs to be satisfied.
store_grad – whether to compute and store the gradient at the end of the linesearch. Since the function is called to compute the value to accept the learning rate, we can also access the gradient along the way. By doing that, we can directly reuse the value and the gradient computed at the end of the linesearch for the next iteration using optax.value_and_grad_from_state(). See the second example below.
verbose – whether to print debugging information.
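The interplay between increase_factor and max_learning_rate described in the parameter list can be sketched in plain Python. This is an illustration derived from the parameter descriptions, not the library's code; the helper name next_learning_rate_guess is hypothetical.

```python
import math


def next_learning_rate_guess(prev_lr, increase_factor=1.5, max_learning_rate=1.0):
    """Hypothetical helper: initial guess for the next line-search round."""
    # Grow the previously accepted learning rate, then clip it to the maximum.
    return min(increase_factor * prev_lr, max_learning_rate)


# The guess grows by increase_factor while it stays below the maximum.
guess_grow = next_learning_rate_guess(0.5)
# A grown guess above max_learning_rate is clipped to max_learning_rate.
guess_clip = next_learning_rate_guess(0.8)
# increase_factor=1. keeps the current guess unchanged.
guess_keep = next_learning_rate_guess(0.3, increase_factor=1.0)
# increase_factor=math.inf restarts from max_learning_rate at each round.
guess_reset = next_learning_rate_guess(0.3, increase_factor=math.inf)
```

The sketch assumes a strictly positive previous learning rate; it does not try to model what happens when a line-search round fails.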
- Returns:
A GradientTransformationExtraArgs, where the update function takes the following additional keyword arguments:
value: value of the function at the current params.
grad: gradient of the function at the current params.
value_fn: function returning the value of the function we seek to optimize.
**extra_args: additional keyword arguments; if the function needs additional arguments such as input data, they should be put there (see the example in this docstring).
Examples
An example on using the backtracking line-search with SGD:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(max_backtracking_steps=15)
... )
>>> # Function with additional inputs other than params
>>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y)
>>> params = jnp.array([1., 2., 3.])
>>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.)
>>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,))
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 5.00E+01
>>> for x, y in zip(xs, ys):
...     value, grad = jax.value_and_grad(fn)(params, x, y)
...     # The extra inputs of fn (x, y) are forwarded as keyword arguments
...     updates, opt_state = solver.update(
...         grad,
...         opt_state,
...         params,
...         value=value,
...         grad=grad,
...         value_fn=fn,
...         x=x,
...         y=y
...     )
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 3.86E+01
Objective function: 2.50E+01
Objective function: 1.34E+01
Objective function: 5.87E+00
Objective function: 5.81E+00
A similar example, but with a non-stochastic function where we can reuse the value and the gradient computed at the end of the linesearch:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> # Function without extra arguments
>>> def fn(params): return jnp.sum(params ** 2)
>>> params = jnp.array([1., 2., 3.])
>>> # In this case we can store value and grad with the store_grad field
>>> # and reuse them using optax.value_and_grad_from_state
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> for _ in range(5):
...     value, grad = value_and_grad(params, state=opt_state)
...     updates, opt_state = solver.update(
...         grad, opt_state, params, value=value, grad=grad, value_fn=fn
...     )
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02
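To make the sufficient decrease criterion concrete, here is a minimal, dependency-free sketch of a backtracking loop using the same symbols as the formula in the description. It is not optax's implementation: the function backtracking_linesearch, the dot helper, the list-based parameters, and the choice to return the last tried learning rate with an accepted flag are all assumptions made for illustration.

```python
def dot(a, b):
    """Scalar product of two equally sized lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def backtracking_linesearch(value_fn, params, updates, grad, init_lr=1.0,
                            slope_rtol=1e-4, decrease_factor=0.8,
                            max_backtracking_steps=15, atol=0.0, rtol=0.0):
    """Illustrative Armijo backtracking; returns (learning_rate, accepted)."""
    value = value_fn(params)
    slope = dot(updates, grad)  # <u, grad f(w)>, negative for a descent direction
    lr = init_lr
    for _ in range(max_backtracking_steps):
        candidate = [p + lr * u for p, u in zip(params, updates)]
        # f(w + lr u) <= (1 + rtol) f(w) + lr * slope_rtol * <u, grad f(w)> + atol
        if value_fn(candidate) <= (1 + rtol) * value + lr * slope_rtol * slope + atol:
            return lr, True
        lr *= decrease_factor
    # Assumption of this sketch: report failure instead of raising.
    return lr, False


def fn(params):
    return sum(p ** 2 for p in params)


params = [1.0, 2.0, 3.0]
grad = [2.0 * p for p in params]   # gradient of the sum of squares
updates = [-g for g in grad]       # plain gradient descent direction
lr, accepted = backtracking_linesearch(fn, params, updates, grad)
new_value = fn([p + lr * u for p, u in zip(params, updates)])
```

On this quadratic, the initial guess of 1.0 maps the parameters to their negation and leaves the objective unchanged, so it is rejected. The first reduced guess of 0.8 is accepted and gives an objective of about 5.04, in line with the first value printed by the second example above.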
References
Vaswani et al., Painless Stochastic Gradient, 2019
Nocedal & Wright, Numerical Optimization, 1999
Warning
The sufficient decrease criterion might be impossible to satisfy for some update directions. To guarantee a non-trivial solution for the sufficient decrease criterion, a descent direction for updates (\(u\)) is required. An update (\(u\)) is considered a descent direction if the derivative of \(f(w + \eta u)\) at \(\eta = 0\) (i.e., \(\langle u, \nabla f(w)\rangle\)) is negative. This condition is automatically satisfied when using optax.sgd() (without momentum), but may not hold true for other optimizers like optax.adam().
More generally, when chained with other transforms as optax.chain(opt_1, ..., opt_k, scale_by_backtracking_linesearch(max_backtracking_steps=...), opt_kplusone, ..., opt_n), the updates returned by chaining opt_1, ..., opt_k must be a descent direction. However, any transform after the backtracking line-search doesn't necessarily need to satisfy the descent direction property (one could for example use momentum).
Note
The algorithm can support complex inputs.
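The descent direction condition from the warning above can be checked numerically before relying on the line-search. The sketch below assumes real-valued parameters stored as flat lists; the helper names dot and is_descent_direction are hypothetical and not part of optax.

```python
def dot(a, b):
    """Scalar product of two equally sized lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def is_descent_direction(updates, grad):
    """True if <updates, grad> is negative, i.e. f decreases for small steps."""
    return dot(updates, grad) < 0.0


grad = [2.0, 4.0, 6.0]
sgd_like_updates = [-g for g in grad]   # negative gradient: always a descent direction
other_updates = [1.0, 0.0, 0.0]         # arbitrary direction with a positive slope

sgd_ok = is_descent_direction(sgd_like_updates, grad)
other_ok = is_descent_direction(other_updates, grad)
```

Here sgd_ok is True and other_ok is False. For updates produced by transforms other than plain SGD, a check of this kind can help explain why the criterion is never met.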
See also
optax.value_and_grad_from_state() to make this method more efficient for non-stochastic objectives.
Added in version 0.2.0.