Optimizer Wrappers#

apply_if_finite(inner, max_consecutive_errors)

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

ApplyIfFiniteState(notfinite_count, ...)

State of the GradientTransformation returned by apply_if_finite.

flatten(inner)

Flattens parameters and gradients for init and update of inner transform.

lookahead(fast_optimizer, sync_period, ...)

Lookahead optimizer.

LookaheadParams(fast, slow)

Holds a pair of slow and fast parameters for the lookahead optimizer.

LookaheadState(fast_state, steps_since_sync)

State of the GradientTransformation returned by lookahead.

masked(inner, mask, *[, ...])

Mask updates so only some are transformed, the rest are passed through.

MaskedState(inner_state)

Maintains inner transform state for masked transformations.

MultiSteps(opt, every_k_schedule, ...)

An optimizer wrapper to accumulate gradients over multiple steps.

MultiStepsState(mini_step, gradient_step, ...)

State of the GradientTransformation returned by MultiSteps.

ShouldSkipUpdateFunction(*args, **kwargs)

skip_large_updates(updates, gradient_step, ...)

Returns True if the global squared norm of updates is too large, i.e. the update should be skipped.

skip_not_finite(updates, gradient_step, params)

Returns True iff any of the updates contains an inf or a NaN.

Apply if finite#

optax.apply_if_finite(inner: optax.GradientTransformation, max_consecutive_errors: int) -> optax.GradientTransformation[source]#

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

The purpose of this function is to prevent any optimization step from happening if the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the gradients, the wrapped optimizer ignores that gradient update. If the NaNs or Infs persist after a given number of updates, the wrapped optimizer gives up and accepts the update.

Parameters:
  • inner – Inner transformation to be wrapped.

  • max_consecutive_errors – Maximum number of consecutive gradient updates containing NaNs or Infs that the wrapped optimizer will ignore. After that many ignored updates, the optimizer will give up and accept the update.

Returns:

New optax.GradientTransformationExtraArgs.

class optax.ApplyIfFiniteState(notfinite_count: Any, last_finite: Any, total_notfinite: Any, inner_state: Any)[source]#

State of the GradientTransformation returned by apply_if_finite.

notfinite_count#

Number of consecutive gradient updates containing an Inf or a NaN. This number is reset to 0 whenever a gradient update without an Inf or a NaN is done.

Type:

Any

last_finite#

Whether the last gradient update was finite, i.e. contained no Inf or NaN.

Type:

Any

total_notfinite#

Total number of gradient updates containing an Inf or a NaN since this optimizer was initialised. This number is never reset.

Type:

Any

inner_state#

The state of the inner GradientTransformation.

Type:

Any

Flatten#

optax.flatten(inner: base.GradientTransformation) -> base.GradientTransformationExtraArgs[source]#

Flattens parameters and gradients for init and update of inner transform.

This can reduce the overhead of performing many calculations on lots of small variables, at the cost of slightly increased memory usage.

Parameters:

inner – Inner transformation to flatten inputs for.

Returns:

New optax.GradientTransformationExtraArgs

Lookahead#

optax.lookahead(fast_optimizer: optax.GradientTransformation, sync_period: jax.typing.ArrayLike, slow_step_size: jax.typing.ArrayLike, reset_state: bool = False) -> optax.GradientTransformation[source]#

Lookahead optimizer.

Performs steps with a fast optimizer and periodically updates a set of slow parameters. Optionally resets the fast optimizer state after synchronization by calling the init function of the fast optimizer.

Updates returned by the lookahead optimizer should not be modified before they are applied, otherwise fast and slow parameters are not synchronized correctly.

Parameters:
  • fast_optimizer – The optimizer to use in the inner loop of lookahead.

  • sync_period – Number of fast optimizer steps to take before synchronizing parameters. Must be >= 1.

  • slow_step_size – Step size of the slow parameter updates.

  • reset_state – Whether to reset the optimizer state of the fast optimizer after each synchronization.

Returns:

An optax.GradientTransformation with init and update functions. The updates passed to the update function should be calculated using the fast lookahead parameters only.

Example

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> fast_opt = optax.sgd(1e-2)
>>> opt = optax.lookahead(fast_opt, sync_period=5, slow_step_size=0.5)
>>> params = optax.LookaheadParams.init_synced(jnp.ones((2,)))
>>> state = opt.init(params)
>>> loss_fn = lambda p: jnp.sum(p**2)
>>> # Calculate gradients wrt the fast parameters
>>> grads = jax.grad(loss_fn)(params.fast)
>>> updates, state = opt.update(grads, state, params)
>>> params = optax.apply_updates(params, updates)
>>> # Calculate the eval loss wrt the slow parameters
>>> loss_fn(params.slow)
Array(2., dtype=float32)

References

Zhang et al, Lookahead Optimizer: k steps forward, 1 step back, 2019

class optax.LookaheadParams(fast: optax.Params, slow: optax.Params)[source]#

Holds a pair of slow and fast parameters for the lookahead optimizer.

Gradients should always be calculated with the fast parameters (i.e., params.fast). The slow parameters should be used for testing and inference as they generalize better. See the reference for a detailed discussion.

fast#

Fast parameters (use these for gradient computation).

Type:

base.Params

slow#

Slow parameters (use these for inference).

Type:

base.Params

References

Zhang et al, Lookahead Optimizer: k steps forward, 1 step back, 2019

class optax.LookaheadState(fast_state: optax.OptState, steps_since_sync: Array)[source]#

State of the GradientTransformation returned by lookahead.

fast_state#

Optimizer state of the fast optimizer.

Type:

base.OptState

steps_since_sync#

Number of fast optimizer steps taken since slow and fast parameters were synchronized.

Type:

jax.Array

Masked update#

optax.masked(inner: base.GradientTransformation, mask: base.PyTree | Callable[[base.Params], base.PyTree], *, mask_compatible_extra_args: bool = False) -> base.GradientTransformationExtraArgs[source]#

Mask updates so only some are transformed, the rest are passed through.

For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. Since in many networks, these are the only 1D parameters, you may for instance create a mask function to mask them out as follows:

mask_fn = lambda p: jax.tree.map(lambda x: x.ndim != 1, p)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)

You may alternatively create the mask pytree upfront:

mask = jax.tree.map(lambda x: x.ndim != 1, params)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)

For the inner transform, state will only be stored for the parameters that have a mask value of True.

Note that, when using tree_map_params, it may be required to pass the argument is_leaf=lambda v: isinstance(v, optax.MaskedNode), if the tree map needs to take additional arguments with the same shape as the original input tree.

Parameters:
  • inner – Inner transformation to mask.

  • mask – a PyTree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. The mask must be static for the gradient transformation to be jit-compilable.

  • mask_compatible_extra_args – whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Returns:

New optax.GradientTransformationExtraArgs wrapping inner.

class optax.MaskedState(inner_state: Any)[source]#

Maintains inner transform state for masked transformations.

Multi-step update#

class optax.MultiSteps(opt: optax.GradientTransformation, every_k_schedule: int | Callable[[jax.typing.ArrayLike], jax.typing.ArrayLike], use_grad_mean: bool = True, should_skip_update_fn: ShouldSkipUpdateFunction | None = None, accumulator_dtype: str | type[Any] | numpy.dtype | SupportsDType = jax.numpy.float32)[source]#

An optimizer wrapper to accumulate gradients over multiple steps.

This wrapper collects together the updates passed to its update function over consecutive steps until a given number of scheduled steps is reached. In each of these intermediate steps, the returned value from the optimizer is a tree of zeros of the same shape as the updates passed as input.

Once the scheduled number of intermediate ‘mini-steps’ has been reached, the gradients accumulated up to that point are passed to the wrapped optimizer’s update function (with the inner optimizer’s state being updated appropriately), and the resulting updates are returned to the caller. The wrapper’s accumulated gradients are then reset to zero and the process starts again.

The number of mini-steps per gradient update is controlled by a function and can vary over training; this also allows the batch size to vary over training.

class optax.MultiStepsState(mini_step: jax.typing.ArrayLike, gradient_step: jax.typing.ArrayLike, inner_opt_state: Any, acc_grads: Any, skip_state: base.ArrayTree = ())[source]#

State of the GradientTransformation returned by MultiSteps.

mini_step#

current mini-step counter. At an update, this either increases by 1 or is reset to 0.

Type:

jax.typing.ArrayLike

gradient_step#

gradient step counter. This only increases after enough mini-steps have been accumulated.

Type:

jax.typing.ArrayLike

inner_opt_state#

the state of the wrapped optimizer.

Type:

Any

acc_grads#

accumulated gradients over multiple mini-steps.

Type:

Any

skip_state#

an arbitrary pytree. This is only relevant when passing a should_skip_update_fn to MultiSteps.

Type:

base.ArrayTree

class optax.ShouldSkipUpdateFunction(*args, **kwargs)[source]#

optax.skip_large_updates(updates: optax.Updates, gradient_step: jax.typing.ArrayLike, params: optax.Params | None, max_squared_norm: jax.typing.ArrayLike) -> tuple[Array, optax.ArrayTree][source]#

Returns True if the global squared norm of updates is too large, i.e. the update should be skipped.

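Parameters:
  • updates – see ShouldSkipUpdateFunction.

  • gradient_step – see ShouldSkipUpdateFunction.

  • params – see ShouldSkipUpdateFunction.

  • max_squared_norm – maximum squared norm of the updates that is accepted.

Returns: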

  • First element is a scalar array of type bool.

  • Second element is a dictionary with keys:

      - should_skip: True iff ||updates||^2 is greater than max_squared_norm.

      - norm_squared: overall squared norm of the updates.

Return type:

A tuple

optax.skip_not_finite(updates: optax.Updates, gradient_step: jax.typing.ArrayLike, params: optax.Params | None) -> tuple[Array, optax.ArrayTree][source]#

Returns True iff any of the updates contains an inf or a NaN.

Parameters:
  • updates – see ShouldSkipUpdateFunction.

  • gradient_step – see ShouldSkipUpdateFunction.

  • params – see ShouldSkipUpdateFunction.

Returns:

  • First element is a scalar array of type bool.

  • Second element is a dictionary with keys:

      - should_skip: True iff updates contains an inf or a NaN.

      - num_not_finite: total number of inf and NaN found in updates.

Return type:

A tuple