optax.ScaleByLBFGSState
- class optax.ScaleByLBFGSState(count: jax.typing.ArrayLike, params: optax.Params, updates: optax.Params, diff_params_memory: optax.ArrayTree, diff_updates_memory: optax.ArrayTree, weights_memory: jax.typing.ArrayLike)
State for the L-BFGS solver.
- count
Current iteration number of the algorithm.
- Type:
jax.typing.ArrayLike
- params
Current parameters.
- Type:
base.Params
- updates
Current updates.
- Type:
base.Params
- diff_params_memory
List of past parameter differences, holding at most memory_size entries
(memory_size is fixed in optax.scale_by_lbfgs()).
- Type:
base.ArrayTree
- diff_updates_memory
List of past gradient/update differences, holding at most memory_size entries
(memory_size is fixed in optax.scale_by_lbfgs()).
- Type:
base.ArrayTree
- weights_memory
List of past weights multiplying the rank-one matrices that define the inverse Hessian approximation; see
optax.scale_by_lbfgs() for more details.
- Type:
jax.typing.ArrayLike
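For context, the textbook BFGS update of the inverse Hessian approximation H shows where such weights appear. Reading diff_params_memory as holding the s_k, diff_updates_memory the y_k, and weights_memory the ρ_k is an interpretation of the field descriptions above, not a statement about optax internals:

```latex
s_k = x_{k+1} - x_k, \qquad y_k = g_{k+1} - g_k, \qquad \rho_k = \frac{1}{y_k^\top s_k},
\qquad
H_{k+1} = \left(I - \rho_k\, s_k y_k^\top\right) H_k \left(I - \rho_k\, y_k s_k^\top\right) + \rho_k\, s_k s_k^\top .
```

The weight ρ_k scales the rank-one factors s_k y_kᵀ, y_k s_kᵀ and s_k s_kᵀ. L-BFGS keeps only the most recent memory_size triples (s_k, y_k, ρ_k) instead of forming H explicitly.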
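As a hedged illustration of how parameter differences, update differences and weights are typically consumed, here is a sketch of the standard L-BFGS two-loop recursion on flat vectors. The helper two_loop_recursion is hypothetical and not part of optax. It assumes the pairs are stored oldest-first along the leading axis (optax may use a circular buffer instead), that rho_mem[i] = 1 / <y_i, s_i>, and an initial approximation H0 = gamma * I.

```python
import jax.numpy as jnp


def two_loop_recursion(grad, s_mem, y_mem, rho_mem, gamma=1.0):
    """Approximate H @ grad with the L-BFGS two-loop recursion.

    Hypothetical helper (not an optax API). Row i of s_mem / y_mem holds the
    i-th oldest parameter / gradient difference; rho_mem[i] = 1 / <y_i, s_i>.
    """
    m = s_mem.shape[0]
    q = grad
    alphas = [None] * m
    # First loop: newest pair to oldest pair.
    for i in reversed(range(m)):
        alphas[i] = rho_mem[i] * jnp.dot(s_mem[i], q)
        q = q - alphas[i] * y_mem[i]
    # Apply the initial inverse Hessian approximation H0 = gamma * I.
    r = gamma * q
    # Second loop: oldest pair to newest pair.
    for i in range(m):
        beta = rho_mem[i] * jnp.dot(y_mem[i], r)
        r = r + s_mem[i] * (alphas[i] - beta)
    return r


# Example with two stored pairs in three dimensions (values are made up).
s_mem = jnp.array([[1.0, 0.0, 0.5], [0.2, 1.0, 0.0]])
y_mem = jnp.array([[2.0, 0.1, 0.4], [0.3, 1.5, 0.2]])
rho_mem = 1.0 / jnp.sum(s_mem * y_mem, axis=1)
direction = two_loop_recursion(jnp.array([0.3, -1.2, 0.7]), s_mem, y_mem, rho_mem)
```

Applied to the most recent gradient difference, the recursion returns the most recent parameter difference (the secant condition H y = s), which makes a quick sanity check for an implementation like this.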