optax.LookaheadState

class optax.LookaheadState(fast_state: optax.OptState, steps_since_sync: Array)

State of the GradientTransformation returned by optax.lookahead. It is a NamedTuple with two fields, fast_state and steps_since_sync.

fast_state

Optimizer state of the fast optimizer.

Type:

optax.OptState

steps_since_sync

Number of fast optimizer steps taken since the slow and fast parameters were last synchronized. It counts up from 0 to sync_period - 1 and wraps back to 0 on each synchronization step.

Type:

jax.Array