optax.ScaleByBacktrackingLinesearchState

class optax.ScaleByBacktrackingLinesearchState(learning_rate: jax.typing.ArrayLike, value: jax.typing.ArrayLike, grad: base.Updates | None, info: BacktrackingLinesearchInfo)

State for optax.scale_by_backtracking_linesearch().

learning_rate

Learning rate (step size) accepted at the end of a round of line search; it is used to scale the update.

Type:

jax.typing.ArrayLike

value

Value of the objective computed at the end of a round of line search. It can be reused with optax.value_and_grad_from_state().

Type:

jax.typing.ArrayLike

grad

Gradient of the objective computed at the end of a round of line search, stored only if the line search was instantiated with store_grad=True; otherwise None. It can be reused with optax.value_and_grad_from_state().

Type:

Optional[base.Updates]

info

Information about the backtracking line-search step, for debugging.

Type:

BacktrackingLinesearchInfo