Optimizer Wrappers

apply_if_finite(inner, max_consecutive_errors)

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

ApplyIfFiniteState(notfinite_count, ...)

State of the GradientTransformation returned by apply_if_finite.
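The skip-on-non-finite idea can be illustrated without the library. The sketch below is not the optax implementation: it works on a flat list of floats, `apply_if_finite_step` and `sgd` are hypothetical helpers, and it assumes that once more than `max_consecutive_errors` bad steps occur in a row the inner update is applied anyway.

```python
import math


def sgd(grads, lr=0.1):
    """Hypothetical inner optimizer: plain gradient descent on a list of floats."""
    return [-lr * g for g in grads]


def apply_if_finite_step(grads, notfinite_count, max_consecutive_errors, inner_update):
    """One step of the skip-on-non-finite logic (illustrative only)."""
    is_finite = all(math.isfinite(g) for g in grads)
    # A clean step resets the counter; a bad step increments it.
    notfinite_count = 0 if is_finite else notfinite_count + 1
    if is_finite or notfinite_count > max_consecutive_errors:
        # Assumption: after too many consecutive bad steps, stop skipping.
        updates = inner_update(grads)
    else:
        # Skip this step by emitting zero updates.
        updates = [0.0 for _ in grads]
    return updates, notfinite_count


updates_ok, count_ok = apply_if_finite_step([1.0, 2.0], 0, 2, sgd)
updates_bad, count_bad = apply_if_finite_step([1.0, float("nan")], 0, 2, sgd)
updates_over, count_over = apply_if_finite_step([float("inf")], 2, 2, sgd)
```

In this sketch a single NaN costs one skipped step, while a long run of bad steps eventually surfaces instead of silently stalling training.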

flatten(inner)

Flattens parameters and gradients for init and update of inner transform.
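As a rough picture of what flattening means here, the following sketch concatenates a dict of float lists into one list, runs a hypothetical inner update on it, and restores the original layout. The helper names are invented for illustration and the real wrapper operates on array pytrees rather than Python lists.

```python
def flatten_tree(tree):
    """Concatenate a dict of float lists into one list and record the layout."""
    keys = sorted(tree)
    flat = [x for k in keys for x in tree[k]]
    layout = [(k, len(tree[k])) for k in keys]
    return flat, layout


def unflatten_tree(flat, layout):
    """Split a flat list back into the dict described by `layout`."""
    out, start = {}, 0
    for key, size in layout:
        out[key] = flat[start:start + size]
        start += size
    return out


def flattened_update(grads, inner_update):
    """Run `inner_update` on the flat view, then restore the structure."""
    flat, layout = flatten_tree(grads)
    return unflatten_tree(inner_update(flat), layout)


grads = {"w": [1.0, 2.0], "b": [3.0]}
updates = flattened_update(grads, lambda flat: [-0.5 * g for g in flat])
```

The inner update only ever sees one flat vector, which is the point of the wrapper.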

lookahead(fast_optimizer, sync_period, ...)

Lookahead optimizer.

LookaheadParams(fast, slow)

Holds a pair of slow and fast parameters for the lookahead optimizer.

LookaheadState(fast_state, steps_since_sync)

State of the GradientTransformation returned by lookahead.
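The interplay of fast and slow parameters can be sketched in plain Python. This is a conceptual illustration, not the library code: `lookahead_step` is a hypothetical helper, `slow_step_size` is a name chosen for this sketch, and it assumes that every `sync_period` steps the slow parameters move a fraction of the way toward the fast ones and the fast parameters are then reset to the slow ones.

```python
def lookahead_step(fast, slow, steps_since_sync, fast_update, sync_period, slow_step_size):
    """Apply one fast update and synchronize every `sync_period` steps."""
    # The fast parameters follow the inner (fast) optimizer on every step.
    fast = [f + u for f, u in zip(fast, fast_update)]
    steps_since_sync += 1
    if steps_since_sync == sync_period:
        # Slow parameters interpolate toward the fast parameters.
        slow = [s + slow_step_size * (f - s) for f, s in zip(fast, slow)]
        # Fast parameters restart from the new slow parameters.
        fast = list(slow)
        steps_since_sync = 0
    return fast, slow, steps_since_sync


fast, slow, steps = [0.0], [0.0], 0
fast, slow, steps = lookahead_step(fast, slow, steps, [1.0], sync_period=2, slow_step_size=0.5)
after_first = (fast, slow, steps)
fast, slow, steps = lookahead_step(fast, slow, steps, [1.0], sync_period=2, slow_step_size=0.5)
after_second = (fast, slow, steps)
```

Here the pair `(fast, slow)` plays the role described for LookaheadParams, and `steps` mirrors the `steps_since_sync` field of LookaheadState.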

masked(inner, mask, *[, ...])

Mask updates so only some are transformed, the rest are passed through.

MaskedState(inner_state)

Maintains inner transform state for masked transformations.
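A minimal sketch of masking, assuming updates and mask are dicts with the same keys and the mask holds booleans. `masked_update` and `halve` are hypothetical helpers for illustration; the real wrapper works on arbitrary pytrees.

```python
def halve(tree):
    """Hypothetical inner transform: scale every update by 0.5."""
    return {k: [0.5 * u for u in v] for k, v in tree.items()}


def masked_update(updates, mask, inner_update):
    """Transform only the entries selected by `mask`; pass the rest through."""
    selected = {k: v for k, v in updates.items() if mask[k]}
    transformed = inner_update(selected)
    # Entries with a False mask keep their original update unchanged.
    return {k: transformed[k] if mask[k] else updates[k] for k in updates}


updates = {"w": [1.0, 2.0], "b": [3.0]}
mask = {"w": True, "b": False}
result = masked_update(updates, mask, halve)
```

A typical use of this pattern is applying a transform such as weight decay to weights while leaving biases untouched.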

MultiSteps(opt, every_k_schedule, ...)

An optimizer wrapper to accumulate gradients over multiple steps.

MultiStepsState(mini_step, gradient_step, ...)

State of the GradientTransformation returned by MultiSteps.
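Gradient accumulation can be sketched as follows. This is an illustration under stated assumptions rather than the library code: `multi_steps_update` is a hypothetical helper, the number of mini-steps `k` is fixed instead of coming from a schedule, gradients are averaged with a running mean, and intermediate mini-steps emit zero updates.

```python
def sgd(grads, lr=0.5):
    """Hypothetical inner optimizer: plain gradient descent on a list of floats."""
    return [-lr * g for g in grads]


def multi_steps_update(grads, acc_grads, mini_step, k, inner_update):
    """Accumulate one mini-step; emit a real update on every k-th call."""
    # Running mean of the gradients seen since the last real update.
    acc_grads = [a + (g - a) / (mini_step + 1) for a, g in zip(acc_grads, grads)]
    if mini_step + 1 == k:
        # Last mini-step: apply the inner optimizer and reset the accumulator.
        return inner_update(acc_grads), [0.0] * len(grads), 0
    # Intermediate mini-step: leave the parameters untouched.
    return [0.0] * len(grads), acc_grads, mini_step + 1


acc, step = [0.0], 0
first_updates, acc, step = multi_steps_update([1.0], acc, step, 2, sgd)
state_after_first = (acc, step)
second_updates, acc, step = multi_steps_update([3.0], acc, step, 2, sgd)
state_after_second = (acc, step)
```

With `k = 2` the two gradients 1.0 and 3.0 are averaged to 2.0 before the inner optimizer runs once, which emulates a batch twice as large.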

ShouldSkipUpdateFunction(*args, **kwargs)

skip_large_updates(updates, gradient_step, ...)

Returns True if the squared global norm of updates is small enough.

skip_not_finite(updates, gradient_step, params)

Returns True iff any of the updates contains an inf or a NaN.
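The two checks described above reduce to simple predicates over the update values. The sketch below uses hypothetical helper names and a flat list of floats; `max_squared_norm` is a threshold name assumed for illustration, and the library functions take additional arguments and operate on pytrees.

```python
import math


def has_non_finite(updates):
    """True iff any update value is an inf or a NaN."""
    return any(not math.isfinite(u) for u in updates)


def squared_global_norm(updates):
    """Sum of squares over all update values."""
    return sum(u * u for u in updates)


def squared_norm_small_enough(updates, max_squared_norm):
    """True if the squared global norm is below the assumed threshold."""
    return squared_global_norm(updates) < max_squared_norm


clean = [3.0, 4.0]
dirty = [1.0, float("inf")]
clean_norm_sq = squared_global_norm(clean)
```

Comparing squared norms avoids a square root and leaves the decision unchanged as long as the threshold is squared too.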