optax.ScaleByLBFGSState

class optax.ScaleByLBFGSState(count: jax.typing.ArrayLike, params: optax.Params, updates: optax.Params, diff_params_memory: optax.ArrayTree, diff_updates_memory: optax.ArrayTree, weights_memory: jax.typing.ArrayLike)

State of the L-BFGS gradient transformation optax.scale_by_lbfgs().

count

Iteration counter of the algorithm, i.e. the number of update steps taken so far.

Type:

jax.typing.ArrayLike

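params

Parameters received at the most recent update step; the next step subtracts them from the new parameters to form a parameter difference.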

Type:

optax.Params

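updates

Updates (typically gradients) received at the most recent update step, before preconditioning; the next step subtracts them from the new updates to form an update difference.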

Type:

optax.Params

diff_params_memory

Stores the most recent differences between successive parameters, s_k = x_{k+1} - x_k, keeping at most memory_size of them, where memory_size is fixed in optax.scale_by_lbfgs().

Type:

optax.ArrayTree

diff_updates_memory

Stores the most recent differences between successive gradients/updates, y_k = g_{k+1} - g_k, keeping at most memory_size of them, where memory_size is fixed in optax.scale_by_lbfgs().

Type:

optax.ArrayTree
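The fields params, updates, count and the two difference memories work together. The NumPy sketch below, using flat arrays instead of pytrees, illustrates one way to maintain such a bounded history: each step forms s_k from the stored and new parameters and y_k from the stored and new updates, then writes both into a circular buffer indexed by the step counter. The helper push_differences is hypothetical, and optax's exact slot ordering may differ.

```python
import numpy as np

def push_differences(diff_params_memory, diff_updates_memory, count,
                     prev_params, prev_updates, params, updates):
    """Hypothetical helper: store one (s, y) pair in a circular buffer.

    Both memories have shape (memory_size, dim); once the buffer is full,
    the oldest slot is overwritten.
    """
    memory_size = diff_params_memory.shape[0]
    s = params - prev_params      # parameter difference s_k
    y = updates - prev_updates    # gradient/update difference y_k
    slot = count % memory_size    # slot written memory_size steps ago
    diff_params_memory = diff_params_memory.copy()
    diff_updates_memory = diff_updates_memory.copy()
    diff_params_memory[slot] = s
    diff_updates_memory[slot] = y
    return diff_params_memory, diff_updates_memory, count + 1

A = np.diag([1.0, 10.0])            # Hessian of f(x) = 0.5 * x^T A x
grad = lambda x: A @ x

memory_size, dim = 3, 2
S = np.zeros((memory_size, dim))
Y = np.zeros((memory_size, dim))
count = 0

x_prev = np.array([1.0, 1.0])
g_prev = grad(x_prev)
for _ in range(5):                  # more steps than slots: the buffer wraps
    x = x_prev - 0.05 * g_prev      # plain gradient step, for illustration
    g = grad(x)
    S, Y, count = push_differences(S, Y, count, x_prev, g_prev, x, g)
    x_prev, g_prev = x, g

# On a quadratic, every stored pair satisfies the secant relation y = A s.
print(np.allclose(Y, S @ A.T))
```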

weights_memory

Past weights rho_i = 1 / <y_i, s_i>, where s_i and y_i are the stored parameter and update differences; each weight multiplies a rank-one matrix in the inverse Hessian approximation. See optax.scale_by_lbfgs() for more details.

Type:

jax.typing.ArrayLike
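In standard L-BFGS, the weights are rho_i = 1 / <y_i, s_i>, and the inverse Hessian approximation is built from gamma * I by repeatedly applying H <- (I - rho_i s_i y_i^T) H (I - rho_i y_i s_i^T) + rho_i s_i s_i^T, which is where the rank-one terms weighted by rho_i come from. The textbook two-loop recursion applies this approximation to a vector without ever forming a matrix. The sketch below is a NumPy version under stated assumptions (pairs stored oldest first, a scalar initial scale gamma, a hypothetical function name two_loop); it is not optax's internal routine.

```python
import numpy as np

def two_loop(g, S, Y, rho, gamma):
    """Approximate H^{-1} g from stored pairs (hypothetical helper).

    S, Y: arrays of shape (m, dim) holding s_i and y_i, oldest first.
    rho: array of shape (m,) with rho_i = 1 / <y_i, s_i>.
    gamma: scale of the initial inverse Hessian approximation gamma * I.
    """
    q = g.copy()
    alpha = np.zeros(len(rho))
    # First loop: newest pair to oldest.
    for i in reversed(range(len(rho))):
        alpha[i] = rho[i] * (S[i] @ q)
        q = q - alpha[i] * Y[i]
    r = gamma * q
    # Second loop: oldest pair to newest.
    for i in range(len(rho)):
        beta = rho[i] * (Y[i] @ r)
        r = r + S[i] * (alpha[i] - beta)
    return r

rng = np.random.default_rng(0)
dim, m = 4, 3
M = rng.standard_normal((dim, dim))
A = M @ M.T + dim * np.eye(dim)     # symmetric positive definite "Hessian"
S = rng.standard_normal((m, dim))   # parameter differences s_i, oldest first
Y = S @ A.T                         # matching update differences y_i = A s_i
rho = 1.0 / np.sum(Y * S, axis=1)   # the weights rho_i = 1 / <y_i, s_i>
gamma = (S[-1] @ Y[-1]) / (Y[-1] @ Y[-1])

# The result satisfies the secant condition for the newest pair: H y = s.
print(np.allclose(two_loop(Y[-1], S, Y, rho, gamma), S[-1]))
```

The choice gamma = <s, y> / <y, y> for the newest pair is a common heuristic for scaling the initial approximation; any positive scalar keeps the recursion well defined.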