optax.ScaleByBacktrackingLinesearchState
- class optax.ScaleByBacktrackingLinesearchState(learning_rate: jax.typing.ArrayLike, value: jax.typing.ArrayLike, grad: base.Updates | None, info: BacktrackingLinesearchInfo)
State for optax.scale_by_backtracking_linesearch().
- learning_rate
Learning rate computed at the end of a round of line search, used to scale the update.
- Type:
jax.typing.ArrayLike
- value
Value of the objective computed at the end of a round of line search. It can be reused with optax.value_and_grad_from_state().
- Type:
jax.typing.ArrayLike
- grad
Gradient of the objective computed at the end of a round of line search if the line search is instantiated with store_grad=True; otherwise None. It can be reused with optax.value_and_grad_from_state().
- Type:
Optional[base.Updates]
- info
Information about the backtracking line-search step, for debugging.
- Type:
BacktrackingLinesearchInfo