Optimizer Wrappers#

apply_if_finite(inner, max_consecutive_errors)

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

ApplyIfFiniteState(notfinite_count, ...)

State of the GradientTransformation returned by apply_if_finite.

flatten(inner)

Flattens parameters and gradients for init and update of inner transform.

lookahead(fast_optimizer, sync_period, ...)

Lookahead optimizer.

LookaheadParams(fast, slow)

Holds a pair of slow and fast parameters for the lookahead optimizer.

LookaheadState(fast_state, steps_since_sync)

State of the GradientTransformation returned by lookahead.

masked(inner, mask, *[, ...])

Mask updates so only some are transformed, the rest are passed through.

MaskedState(inner_state)

Maintains inner transform state for masked transformations.

maybe_update(inner, should_update_fn)

Calls the inner update function only at certain steps.

MaybeUpdateState

alias of ConditionallyTransformState

MultiSteps(opt, every_k_schedule[, ...])

An optimizer wrapper to accumulate gradients over multiple steps.

MultiStepsState(mini_step, gradient_step, ...)

State of the GradientTransformation returned by MultiSteps.

ShouldSkipUpdateFunction(*args, **kwargs)

Callable that decides whether the updates of a mini-step should be skipped.

skip_large_updates(updates, gradient_step, ...)

Returns True if the squared global norm of updates exceeds max_squared_norm, i.e. when the update should be skipped.

skip_not_finite(updates, gradient_step, params)

Returns True iff any of the updates contains an inf or a NaN.

Apply if finite#

optax.apply_if_finite(inner, max_consecutive_errors)[source]#

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

The purpose of this function is to prevent any optimization from happening when the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the gradients, the wrapped optimizer ignores that gradient update. If the NaNs or Infs persist for a given number of consecutive updates, the wrapped optimizer gives up and accepts the update.

Parameters:
  • inner (GradientTransformation) – Inner transformation to be wrapped.

  • max_consecutive_errors (int) – Maximum number of consecutive gradient updates containing NaNs or Infs that the wrapped optimizer will ignore. After that many ignored updates, the optimizer will give up and accept the update.

Return type:

GradientTransformation

Returns:

New GradientTransformationExtraArgs.
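
For example, a minimal usage sketch (the inner optimizer, parameter and gradient values below are illustrative):

import jax.numpy as jnp
import optax

# Wrap Adam so that up to 5 consecutive non-finite gradient updates are
# ignored before the wrapper gives up and accepts the update.
opt = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)

params = {'w': jnp.ones((2,))}
state = opt.init(params)

grads = {'w': jnp.array([jnp.inf, 1.0])}  # a non-finite gradient
updates, state = opt.update(grads, state, params)
# The update is rejected: `updates` is a tree of zeros and
# `state.notfinite_count` is incremented.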

class optax.ApplyIfFiniteState(notfinite_count: Any, last_finite: Any, total_notfinite: Any, inner_state: Any)[source]#

State of the GradientTransformation returned by apply_if_finite.

Fields:

  • notfinite_count: Number of consecutive gradient updates containing an Inf or a NaN. This number is reset to 0 whenever a gradient update without an Inf or a NaN is done.

  • last_finite: Whether the last gradient update was finite, i.e. contained no Inf or NaN.

  • total_notfinite: Total number of gradient updates containing an Inf or a NaN since this optimizer was initialised. This number is never reset.

  • inner_state: The state of the inner GradientTransformation.

notfinite_count: Any#

Alias for field number 0

last_finite: Any#

Alias for field number 1

total_notfinite: Any#

Alias for field number 2

inner_state: Any#

Alias for field number 3

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Flatten#

optax.flatten(inner)[source]#

Flattens parameters and gradients for init and update of inner transform.

This can reduce the overhead of performing many calculations on lots of small variables, at the cost of slightly increased memory usage.

Parameters:

inner (GradientTransformation) – Inner transformation to flatten inputs for.

Return type:

GradientTransformationExtraArgs

Returns:

New GradientTransformationExtraArgs
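
A minimal sketch (the inner optimizer is illustrative):

import optax

# Flatten all parameter leaves into a single vector before running Adam,
# trading slightly higher memory usage for fewer per-variable operations.
opt = optax.flatten(optax.adam(1e-3))
# `opt` is then used like any other transformation:
# state = opt.init(params); updates, state = opt.update(grads, state, params)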

Lookahead#

optax.lookahead(fast_optimizer, sync_period, slow_step_size, reset_state=False)[source]#

Lookahead optimizer.

Performs steps with a fast optimizer and periodically updates a set of slow parameters. Optionally resets the fast optimizer state after synchronization by calling the init function of the fast optimizer.

Updates returned by the lookahead optimizer should not be modified before they are applied, otherwise fast and slow parameters are not synchronized correctly.

References

[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

Parameters:
  • fast_optimizer (GradientTransformation) – The optimizer to use in the inner loop of lookahead.

  • sync_period (int) – Number of fast optimizer steps to take before synchronizing parameters. Must be >= 1.

  • slow_step_size (float) – Step size of the slow parameter updates.

  • reset_state (bool) – Whether to reset the optimizer state of the fast optimizer after each synchronization.

Return type:

GradientTransformation

Returns:

A GradientTransformation with init and update functions. The updates passed to the update function should be calculated using the fast lookahead parameters only.
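
A minimal training-loop sketch (the toy loss, optimizer choices and hyperparameters are illustrative):

import jax
import jax.numpy as jnp
import optax

def loss_fn(p):  # toy quadratic loss, for illustration only
  return jnp.sum(p['w'] ** 2)

opt = optax.lookahead(optax.sgd(0.1), sync_period=5, slow_step_size=0.5)

params = optax.LookaheadParams.init_synced({'w': jnp.ones((3,))})
state = opt.init(params)

for _ in range(10):
  # Gradients are computed with the fast parameters only.
  grads = jax.grad(loss_fn)(params.fast)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)

# Use `params.slow` for evaluation and inference.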

class optax.LookaheadParams(fast: base.Params, slow: base.Params)[source]#

Holds a pair of slow and fast parameters for the lookahead optimizer.

Gradients should always be calculated with the fast parameters. The slow parameters should be used for testing and inference as they generalize better. See the reference for a detailed discussion.

References

[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

fast#

Fast parameters.

slow#

Slow parameters.

fast: base.Params#

Alias for field number 0

slow: base.Params#

Alias for field number 1

classmethod init_synced(params)[source]#

Initialize a pair of synchronized lookahead parameters.

Return type:

LookaheadParams

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.LookaheadState(fast_state: base.OptState, steps_since_sync: jnp.ndarray)[source]#

State of the GradientTransformation returned by lookahead.

fast_state#

Optimizer state of the fast optimizer.

steps_since_sync#

Number of fast optimizer steps taken since slow and fast parameters were synchronized.

fast_state: base.OptState#

Alias for field number 0

steps_since_sync: jnp.ndarray#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Masked update#

optax.masked(inner, mask, *, mask_compatible_extra_args=False)[source]#

Mask updates so only some are transformed, the rest are passed through.

For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. Since in many networks these are the only 1D parameters, you may, for instance, create a mask function to mask them out as follows:

import jax.tree_util as jtu  # provides the `jtu` alias used below
mask_fn = lambda p: jtu.tree_map(lambda x: x.ndim != 1, p)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)

You may alternatively create the mask pytree upfront:

mask = jtu.tree_map(lambda x: x.ndim != 1, params)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)

For the inner transform, state will only be stored for the parameters that have a mask value of True.

Note that, when using tree_map_params, it may be required to pass the argument is_leaf=lambda v: isinstance(v, optax.MaskedNode), if the tree map needs to take additional arguments with the same shape as the original input tree.

Parameters:
  • inner (optax.GradientTransformation) – Inner transformation to mask.

  • mask (Union[base.PyTree, Callable[[optax.Params], base.PyTree]]) – a PyTree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. The mask must be static for the gradient transformation to be jit-compilable.

  • mask_compatible_extra_args (bool) – whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Return type:

optax.GradientTransformationExtraArgs

Returns:

New GradientTransformationExtraArgs wrapping inner.
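
For instance, a sketch combining the masked weight decay above with Adam via optax.chain (the learning rate and decay value are illustrative):

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}

# Decay only parameters with more than one dimension (skips the 1D bias).
mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p)
opt = optax.chain(
    optax.masked(optax.add_decayed_weights(1e-4), mask_fn),
    optax.adam(1e-3),
)

state = opt.init(params)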

class optax.MaskedState(inner_state: Any)[source]#

Maintains inner transform state for masked transformations.

inner_state: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Maybe update#

optax.maybe_update(inner, should_update_fn)[source]#

Calls the inner update function only at certain steps.

Creates a transformation wrapper that counts how many times the update function has been called; this step counter is passed to should_update_fn to decide whether to apply the inner transformation. When the inner update is skipped, the updates and the inner state are passed through unchanged, while the step counter still advances.

Return type:

GradientTransformationExtraArgs

optax.MaybeUpdateState[source]#

alias of ConditionallyTransformState
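
A sketch, assuming should_update_fn receives the current step counter (the inner transform and schedule are illustrative):

import optax

# Apply global-norm clipping only on even steps; on other steps the
# updates and the inner state pass through unchanged.
opt = optax.maybe_update(optax.clip_by_global_norm(1.0),
                         should_update_fn=lambda step: step % 2 == 0)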

Multi-step update#

class optax.MultiSteps(opt, every_k_schedule, use_grad_mean=True, should_skip_update_fn=None)[source]#

An optimizer wrapper to accumulate gradients over multiple steps.

This wrapper collects together the updates passed to its update function over consecutive steps until a given number of scheduled steps is reached. In each of these intermediate steps, the value returned by the optimizer is a tree of zeros of the same shape as the updates passed as input.

Once the scheduled number of intermediate ‘mini-steps’ has been reached, the gradients accumulated up to that point are passed to the wrapped optimizer’s update function (with the inner optimizer’s state updated appropriately) and the result is returned to the caller. The wrapper’s accumulated gradients are then reset to zero and the process starts again.

The number of mini-steps per gradient update is controlled by a function and can vary over training; this also allows the effective batch size to vary over training.
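
A minimal sketch (the inner optimizer, schedule and gradient values are illustrative):

import jax.numpy as jnp
import optax

# Accumulate gradients over 4 mini-steps before applying a single Adam update.
opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)

params = {'w': jnp.ones((2,))}
state = opt.init(params)

grads = {'w': jnp.full((2,), 0.1)}
for _ in range(4):
  # The first three calls return zero updates; the fourth returns the Adam
  # update computed from the averaged gradients.
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)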

__init__(opt, every_k_schedule, use_grad_mean=True, should_skip_update_fn=None)[source]#

Initialiser.

Parameters:
  • opt (GradientTransformation) – the wrapped optimizer.

  • every_k_schedule (Union[int, Callable[[Union[Array, ndarray, bool_, number]], Union[Array, ndarray, bool_, number]]]) –

    an int or a function.

    • As a function, it returns how many mini-steps should be accumulated in a single gradient step. Its only argument is the current gradient step count. By varying the returned value, users can vary the overall training batch size.

    • If an int, this is the constant number of mini-steps per gradient update.

  • use_grad_mean (bool) – if True (the default), gradients accumulated over multiple mini-steps are averaged. Otherwise, they are summed.

  • should_skip_update_fn (Optional[ShouldSkipUpdateFunction]) –

    if provided, this function is used to decide when to accept or reject the updates from a mini-step. When a mini-step is rejected, the inner state of MultiSteps is not updated. In other words, it is as if this mini-step never happened. For example:

    • to ignore updates containing inf or NaN, do should_skip_update_fn=skip_not_finite;

    • to ignore updates with a squared norm larger than 42, do: should_skip_update_fn=functools.partial(skip_large_updates, max_squared_norm=42.)

    Note that the optimizer’s state optax.MultiStepsState contains a field skip_state in which debugging and monitoring information returned by should_skip_update_fn is written.

init(params)[source]#

Builds and returns initial MultiStepsState.

Return type:

MultiStepsState

update(updates, state, params=None, **extra_args)[source]#

Accumulates gradients and proposes non-zero updates every k_steps.

Return type:

tuple[optax.Updates, MultiStepsState]

class optax.MultiStepsState(mini_step: chex.Array, gradient_step: chex.Array, inner_opt_state: Any, acc_grads: Any, skip_state: chex.ArrayTree = ())[source]#

State of the GradientTransformation returned by MultiSteps.

Fields:

  • mini_step: current mini-step counter. At an update, this either increases by 1 or is reset to 0.

  • gradient_step: gradient step counter. This only increases after enough mini-steps have been accumulated.

  • inner_opt_state: the state of the wrapped optimizer.

  • acc_grads: accumulated gradients over multiple mini-steps.

  • skip_state: an arbitrary pytree. This is only relevant when passing a should_skip_update_fn to MultiSteps.

mini_step: chex.Array#

Alias for field number 0

gradient_step: chex.Array#

Alias for field number 1

inner_opt_state: Any#

Alias for field number 2

acc_grads: Any#

Alias for field number 3

skip_state: chex.ArrayTree#

Alias for field number 4

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ShouldSkipUpdateFunction(*args, **kwargs)[source]#
__call__(updates, gradient_step, params)[source]#

Returns true to indicate that updates should be skipped in a multi-step.

Parameters:
  • updates (optax.Updates) – The updates that the gradient transformation has proposed.

  • gradient_step (chex.Array) – The current gradient step (see MultiStepsState.gradient_step). This can be used for example to reject large gradients with an annealed maximum allowed gradient norm.

  • params (Optional[optax.Params]) – If known, the current params of the function being transformed.

Returns:

  • First element is an array with a single bool indicating whether or not the updates should be applied.

  • Second element is an arbitrary py-tree that will be stored in MultiStepsState.skip_state. Debugging info can be put here.

Return type:

A tuple

__init__(*args, **kwargs)[source]#
__subclasshook__()[source]#

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
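
A sketch of a custom skip function following this protocol (the threshold and the returned dictionary keys are illustrative, not part of optax):

import jax
import jax.numpy as jnp
import optax

def skip_huge_updates(updates, gradient_step, params):
  # Skip the mini-step whenever any update entry exceeds an (arbitrary)
  # absolute threshold of 100.
  del gradient_step, params  # unused here
  max_abs = jnp.max(jnp.array(
      [jnp.max(jnp.abs(u)) for u in jax.tree_util.tree_leaves(updates)]))
  should_skip = max_abs > 100.0
  return should_skip, {'should_skip': should_skip, 'max_abs': max_abs}

opt = optax.MultiSteps(optax.sgd(1e-2), every_k_schedule=2,
                       should_skip_update_fn=skip_huge_updates)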

optax.skip_large_updates(updates, gradient_step, params, max_squared_norm)[source]#

Returns True if the squared global norm of updates exceeds max_squared_norm, i.e. when the update should be skipped.

Parameters:
  • updates (optax.Updates) – see ShouldSkipUpdateFunction.

  • gradient_step (chex.Array) – see ShouldSkipUpdateFunction.

  • params (Optional[optax.Params]) – see ShouldSkipUpdateFunction.

  • max_squared_norm (float) – max square norm that can be accepted in updates.

Returns:

  • First element is a scalar array of type bool.

  • Second element is a dictionary with keys:

    • should_skip: True iff ||updates||^2 is greater than max_squared_norm.

    • norm_squared: overall squared norm of the updates.

Return type:

A tuple

optax.skip_not_finite(updates, gradient_step, params)[source]#

Returns True iff any of the updates contains an inf or a NaN.

Parameters:
  • updates (optax.Updates) – see ShouldSkipUpdateFunction.

  • gradient_step (chex.Array) – see ShouldSkipUpdateFunction.

  • params (Optional[optax.Params]) – see ShouldSkipUpdateFunction.

Returns:

  • First element is a scalar array of type bool.

  • Second element is a dictionary with keys:

    • should_skip: True iff updates contains an inf or a NaN.

    • num_not_finite: total number of inf and NaN values found in updates.

Return type:

A tuple
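
For example, a sketch passing this function to MultiSteps so that non-finite mini-steps are rejected (the inner optimizer is illustrative):

import optax

opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=1,
                       should_skip_update_fn=optax.skip_not_finite)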