optax.value_and_grad_from_state
optax.value_and_grad_from_state(value_fn: Callable[[...], jax.typing.ArrayLike]) -> Callable[[...], tuple[jax.typing.ArrayLike, optax.Updates]]
Alternative to jax.value_and_grad() that fetches the value and gradient from the optimizer state.

Line-search methods such as optax.scale_by_backtracking_linesearch() need to compute the objective value and its gradient at the candidate iterate. This utility function lets the next iteration reuse that value and gradient instead of recomputing them.

Parameters:
value_fn – function returning a scalar (float or array of dimension 1), amenable to differentiation in JAX using jax.value_and_grad().

Returns:
A callable akin to jax.value_and_grad() that fetches the value and gradient from the state if present. If no value or gradient is found, or if multiple values and gradients are found, it raises an error. If a value is found but is infinite or NaN, the value and gradient are computed using jax.value_and_grad(). If the gradient found in the state is None, it raises an error.
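The lookup rules described above can be pictured with a small pure-Python model. This is only a conceptual sketch, not optax's implementation: the helper `value_and_grad_from_cache`, the dict-based `cache`, and the toy function `square_value_and_grad` are all hypothetical names introduced for illustration, and the real utility reads from an optax optimizer state rather than a dict.

```python
import math


def value_and_grad_from_cache(value_and_grad_fn):
    """Hypothetical stand-in: reuse a cached (value, grad) pair when it is usable."""

    def wrapped(x, *, cache):
        # No stored entries at all: nothing to reuse, so fail loudly.
        if "value" not in cache or "grad" not in cache:
            raise KeyError("cache must hold both 'value' and 'grad'")
        # A stored gradient of None is treated as an error.
        if cache["grad"] is None:
            raise ValueError("cached gradient is None")
        # A finite cached value can be reused as-is.
        if math.isfinite(cache["value"]):
            return cache["value"], cache["grad"]
        # Infinite or NaN value (e.g. a freshly initialised cache): recompute.
        return value_and_grad_fn(x)

    return wrapped


calls = []  # records every real evaluation of the objective


def square_value_and_grad(x):
    calls.append(x)
    return x * x, 2.0 * x


cached = value_and_grad_from_cache(square_value_and_grad)
# The first call sees a placeholder value of +inf, so it recomputes.
fresh = cached(3.0, cache={"value": math.inf, "grad": 0.0})
# The second call finds a finite value and reuses it without evaluating again.
reused = cached(3.0, cache={"value": 9.0, "grad": 6.0})
```

In this sketch the objective is evaluated once across the two calls, which is the kind of saving the utility aims for when a line search has already evaluated the candidate iterate.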
Examples
>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02