optax.value_and_grad_from_state

optax.value_and_grad_from_state(value_fn: Callable[[...], jax.typing.ArrayLike]) -> Callable[[...], tuple[jax.typing.ArrayLike, optax.Updates]]

Alternative to jax.value_and_grad() that fetches the value and gradient from the optimizer state.

Line-search methods such as optax.scale_by_backtracking_linesearch() need to compute the objective value and its gradient at the candidate iterate. This utility function lets the next iteration reuse that value and gradient instead of recomputing them.

Parameters:

value_fn – function returning a scalar (float or array of dimension 1), amenable to differentiation in JAX using jax.value_and_grad().

Returns:

A callable akin to jax.value_and_grad() that fetches the value and gradient from the state when they are present. It raises an error if no value or gradient is found in the state, if multiple values and gradients are found, or if the gradient found in the state is None. If a value is found but is infinite or NaN, the value and gradient are recomputed using jax.value_and_grad().
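The fetch-or-recompute behaviour described above can be illustrated with a small pure-Python sketch. This is not optax's implementation: the dict-based state, the helper name value_and_grad_from_cache, and the hand-written gradient are all assumptions made for illustration, standing in for the optimizer state and jax.value_and_grad().

```python
import math


def value_and_grad_from_cache(value_and_grad_fn):
    """Hypothetical stand-in: reuse a cached (value, grad) pair when usable."""

    def wrapped(params, *, state):
        # `state` is a plain dict in this sketch, not a real optimizer state.
        if "value" not in state or "grad" not in state:
            raise ValueError("state holds no cached value/grad")
        value, grad = state["value"], state["grad"]
        if grad is None:
            raise ValueError("cached gradient is None")
        if not math.isfinite(value):
            # Cached value is inf or nan: fall back to a fresh computation.
            return value_and_grad_fn(params)
        # Cached pair is usable: return it without calling the function.
        return value, grad

    return wrapped


counter = {"calls": 0}  # counts how often the objective is actually evaluated


def sq_norm_value_and_grad(x):
    """Value and gradient of sum(x ** 2), written by hand for this sketch."""
    counter["calls"] += 1
    return sum(xi * xi for xi in x), [2.0 * xi for xi in x]


cached = value_and_grad_from_cache(sq_norm_value_and_grad)

# A non-finite cached value triggers a recomputation.
stale_state = {"value": math.inf, "grad": [0.0, 0.0]}
value, grad = cached([1.0, 2.0], state=stale_state)

# A finite cached value is returned as-is, with no extra evaluation.
warm_state = {"value": 5.0, "grad": [2.0, 4.0]}
value_reused, grad_reused = cached([1.0, 2.0], state=warm_state)
```

In this sketch the objective is evaluated only once across the two calls, which is the saving the utility aims for when a line search has already stored the value and gradient at the accepted iterate.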

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02