Transformations

Transformations

adaptive_grad_clip(clipping[, eps])

Clips updates to be at most clipping * parameter_norm, unit-wise.

AdaptiveGradClipState

alias of EmptyState

add_decayed_weights([weight_decay, mask])

Add the parameters, scaled by weight_decay, to the updates.

AddDecayedWeightsState

alias of EmptyState

add_noise(eta, gamma, seed)

Add gradient noise.

AddNoiseState(count, rng_key)

State for adding gradient noise.

apply_every([k])

Accumulate gradients and apply them every k steps.

ApplyEvery(count, grad_acc)

Contains a counter and a gradient accumulator.

bias_correction(moment, decay, count)

Performs bias correction.

centralize()

Centralize gradients.

clip(max_delta)

Clips updates element-wise, to be in [-max_delta, +max_delta].

clip_by_block_rms(threshold)

Clips updates to a max rms for the gradient of each param vector or matrix.

ClipState

alias of EmptyState

clip_by_global_norm(max_norm)

Clips updates using their global norm.

ClipByGlobalNormState

alias of EmptyState

ema(decay[, debias, accumulator_dtype])

Compute an exponential moving average of past updates.

EmaState(count, ema)

Holds an exponential moving average of past updates.

EmptyState()

An empty state for the simplest stateless transformations.

global_norm(updates)

Compute the global norm across a nested structure of tensors.

GradientTransformation(init, update)

A pair of pure functions implementing a gradient transformation.

GradientTransformationExtraArgs(init, update)

A specialization of GradientTransformation that supports extra args.

identity()

Stateless identity transformation that leaves input gradients untouched.

keep_params_nonnegative()

Modifies the updates to keep parameters non-negative, i.e. >= 0.

NonNegativeParamsState

alias of EmptyState

OptState

alias of Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Params

alias of Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

per_example_global_norm_clip(grads, l2_norm_clip)

Applies gradient clipping per-example using their global norm.

per_example_layer_norm_clip(grads, ...[, ...])

Applies gradient clipping per-example using per-layer norms.

scale(step_size)

Scale updates by some fixed scalar step_size.

ScaleState

alias of EmptyState

scale_by_adadelta([rho, eps])

Rescale updates according to the Adadelta algorithm.

ScaleByAdaDeltaState(e_g, e_x)

State for the rescaling by Adadelta algorithm.

scale_by_adam([b1, b2, eps, eps_root, ...])

Rescale updates according to the Adam algorithm.

scale_by_adamax([b1, b2, eps])

Rescale updates according to the Adamax algorithm.

ScaleByAdamState(count, mu, nu)

State for the Adam algorithm.

scale_by_amsgrad([b1, b2, eps, eps_root, ...])

Rescale updates according to the AMSGrad algorithm.

ScaleByAmsgradState(count, mu, nu, nu_max)

State for the AMSGrad algorithm.

scale_by_backtracking_linesearch(...[, ...])

Backtracking line-search ensuring sufficient decrease (Armijo criterion).

ScaleByBacktrackingLinesearchState(...[, grad])

State for optax.scale_by_backtracking_linesearch().

scale_by_belief([b1, b2, eps, eps_root])

Rescale updates according to the AdaBelief algorithm.

ScaleByBeliefState(count, mu, nu)

State for the rescaling by AdaBelief algorithm.

scale_by_factored_rms([factored, ...])

Scaling by a factored estimate of the gradient rms (as in Adafactor).

FactoredState(count, v_row, v_col, v)

Overall state of the gradient transformation.

scale_by_learning_rate(learning_rate, *[, ...])

Scale by the (negative) learning rate (either as scalar or as schedule).

scale_by_lion([b1, b2, mu_dtype])

Rescale updates according to the Lion algorithm.

ScaleByLionState(count, mu)

State for the Lion algorithm.

scale_by_novograd([b1, b2, eps, eps_root, ...])

Computes NovoGrad updates.

ScaleByNovogradState(count, mu, nu)

State for Novograd.

scale_by_optimistic_gradient([alpha, beta])

Compute generalized optimistic gradients.

scale_by_param_block_norm([min_scale])

Scale updates for each param block by the norm of that block's parameters.

scale_by_param_block_rms([min_scale])

Scale updates by rms of the gradient for each param vector or matrix.

scale_by_polyak([f_min, max_learning_rate, eps])

Scales the update by Polyak's step-size.

scale_by_radam([b1, b2, eps, eps_root, ...])

Rescale updates according to the Rectified Adam algorithm.

scale_by_rms([decay, eps, initial_scale])

Rescale updates by the root of the exponential moving average of the squared updates.

ScaleByRmsState(nu)

State for exponential root mean-squared (RMS)-normalized updates.

scale_by_rprop(learning_rate[, eta_minus, ...])

Scale with the Rprop optimizer.

ScaleByRpropState(step_sizes, prev_updates)

scale_by_rss([initial_accumulator_value, eps])

Rescale updates by the root of the sum of all squared gradients to date.

ScaleByRssState(sum_of_squares)

State holding the sum of gradient squares to date.

scale_by_schedule(step_size_fn)

Scale updates using a custom schedule for the step_size.

ScaleByScheduleState(count)

Maintains count for scale scheduling.

scale_by_sm3([b1, b2, eps])

Scale updates by SM3.

ScaleBySM3State(mu, nu)

State for the SM3 algorithm.

scale_by_stddev([decay, eps, initial_scale])

Rescale updates by the root of the centered exponential moving average of the squared updates.

ScaleByRStdDevState(mu, nu)

State for centered exponential moving average of squares of updates.

scale_by_trust_ratio([min_norm, ...])

Scale updates by trust ratio.

ScaleByTrustRatioState()

The scale and decay trust ratio transformation is stateless.

scale_by_yogi([b1, b2, eps, eps_root, ...])

Rescale updates according to the Yogi algorithm.

set_to_zero()

Stateless transformation that maps input gradients to zero.

stateless(f)

Creates a stateless transformation from an update-like function.

stateless_with_tree_map(f)

Creates a stateless transformation from an update-like function for arrays.

trace(decay[, nesterov, accumulator_dtype])

Compute a trace of past updates.

TraceState(trace)

Holds an aggregation of past updates.

TransformInitFn(*args, **kwargs)

A callable type for the init step of a GradientTransformation.

TransformUpdateFn(*args, **kwargs)

A callable type for the update step of a GradientTransformation.

update_infinity_moment(updates, moments, ...)

Compute the exponential moving average of the infinity norm.

update_moment(updates, moments, decay, order)

Compute the exponential moving average of the order-th moment.

update_moment_per_elem_norm(updates, ...)

Compute the EMA of the order-th moment of the element-wise norm.

Updates

alias of Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

with_extra_args_support(tx)

Wraps a gradient transformation so that it ignores extra args.

zero_nans()

A transformation which replaces NaNs with 0.

ZeroNansState(found_nan)

Contains a tree of booleans recording which update arrays contained NaNs at the last call.

Types

class optax.GradientTransformation(init: TransformInitFn, update: TransformUpdateFn)

A pair of pure functions implementing a gradient transformation.

Optax optimizers are all implemented as gradient transformations. A gradient transformation is defined to be a pair of pure functions, which are combined together in a NamedTuple so that they can be referred to by name.

Note that an extended API is provided for users wishing to build optimizers that take additional arguments during the update step. For more details, see GradientTransformationExtraArgs.

Since gradient transformations do not contain any internal state, all stateful optimizer properties (such as the current step count when using optimizer schedules, or momentum values) are passed through Optax gradient transformations by using the optimizer state pytree. Each time a gradient transformation is applied, a new state is computed and returned, ready to be passed to the next call to the gradient transformation.

Since gradient transformations are pure functions, the only way to change the behaviour of a gradient transformation between steps is to change the values in the optimizer state. To see an example of modifying the optimizer state in order to control the behaviour of an Optax gradient transformation, see the meta-learning example in the Optax documentation.

init

A pure function which, when called with an example instance of the parameters whose gradients will be transformed, returns a pytree containing the initial value for the optimizer state.

update

A pure function which takes as input a pytree of updates (with the same tree structure as the original params pytree passed to init), the previous optimizer state (which may have been initialized using the init function), and optionally the current params. The update function then returns the computed gradient updates, and a new optimizer state.

init: TransformInitFn

Alias for field number 0

update: TransformUpdateFn

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

class optax.GradientTransformationExtraArgs(init: TransformInitFn, update: TransformUpdateFn)

A specialization of GradientTransformation that supports extra args.

Extends the existing GradientTransformation interface by adding support for passing extra arguments to the update function.

Note that if no extra args are provided, then the API of this function is identical to the case of TransformUpdateFn. This means that we can safely wrap any gradient transformation (that does not support extra args) as one that does. The new gradient transformation will accept (and ignore) any extra arguments that a user might pass to it. This is the behavior implemented by optax.with_extra_args_support().

update

Overrides the type signature of the update in the base type to accept extra arguments.

class optax.TransformInitFn(*args, **kwargs)

A callable type for the init step of a GradientTransformation.

The init step takes a tree of params and uses these to construct an arbitrary structured initial state for the gradient transformation. This may hold statistics of the past updates or any other non-static information.

__call__(params)

The init function.

Parameters:

params (Params) – The initial value of the parameters.

Return type:

OptState

Returns:

The initial state of the gradient transformation.

__init__(*args, **kwargs)
__subclasshook__()

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).

class optax.TransformUpdateFn(*args, **kwargs)

A callable type for the update step of a GradientTransformation.

The update step takes a tree of candidate parameter updates (e.g. their gradient with respect to some loss), an arbitrary structured state, and the current params of the model being optimised. The params argument is optional; however, it must be provided when using transformations that require access to the current values of the parameters.

For the case where additional arguments are required, an alternative interface may be used, see TransformUpdateExtraArgsFn for details.

__call__(updates, state, params=None)

The update function.

Parameters:
  • updates (Updates) – A tree of candidate updates.

  • state (OptState) – The state of the gradient transformation.

  • params (Optional[Params]) – (Optionally) the current value of the parameters.

Return type:

tuple[Updates, OptState]

Returns:

The transformed updates, and the updated state.

__init__(*args, **kwargs)
__subclasshook__()

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).

optax.OptState

alias of Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

optax.Params

alias of Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

optax.Updates

alias of Union[Array, ndarray, bool_, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Transformations and states

optax.adaptive_grad_clip(clipping, eps=0.001)

Clips updates to be at most clipping * parameter_norm, unit-wise.

References

[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)

Parameters:
  • clipping (float) – The maximum allowed ratio of update norm to parameter norm.

  • eps (float) – An epsilon term to prevent clipping of zero-initialized params.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.AdaptiveGradClipState

alias of EmptyState

optax.add_decayed_weights(weight_decay=0.0, mask=None)

Add the parameters, scaled by weight_decay, to the updates.

Parameters:
  • weight_decay (Union[float, jax.Array]) – A scalar weight decay rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type:

optax.GradientTransformation

Returns:

A GradientTransformation object.

optax.AddDecayedWeightsState

alias of EmptyState

optax.add_noise(eta, gamma, seed)

Add gradient noise.

References

[Neelakantan et al, 2015](https://arxiv.org/abs/1511.06807)

Parameters:
  • eta (float) – Base variance of the Gaussian noise added to the gradient.

  • gamma (float) – Decay exponent for annealing of the variance.

  • seed (int) – Seed for random number generation.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.AddNoiseState(count: chex.Array, rng_key: chex.PRNGKey)

State for adding gradient noise. Contains a count for annealing.

count: Union[Array, ndarray, bool_, number]

Alias for field number 0

rng_key: Array

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.apply_every(k=1)

Accumulate gradients and apply them every k steps.

Note that if this transformation is part of a chain, the states of the other transformations will still be updated at every step. In particular, using apply_every with a batch size of N/2 and k=2 is not necessarily equivalent to not using apply_every with a batch size of N. If this equivalence is important for you, consider using optax.MultiSteps.

Parameters:

k (int) – Emit non-zero gradients every k steps, otherwise accumulate them.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.ApplyEvery(count: chex.Array, grad_acc: base.Updates)

Contains a counter and a gradient accumulator.

count: chex.Array

Alias for field number 0

grad_acc: base.Updates

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.bias_correction(moment, decay, count)

Performs bias correction. It becomes a no-op as count goes to infinity.
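The sketch below applies the standard Adam-style correction by hand in plain JAX, assuming it divides a zero-initialized moving average by 1 - decay**count (it does not call optax.bias_correction itself). After one step with decay 0.9, a constant signal of 5.0 is underestimated as 0.5; the correction restores 5.0 at every step, and since decay**count goes to 0 the denominator approaches 1, which is why the correction fades out.

```python
import jax.numpy as jnp

decay = 0.9
signal = jnp.array(5.0)

moment = jnp.zeros(())  # zero-initialized running average
corrected = []
for count in range(1, 4):
  moment = decay * moment + (1.0 - decay) * signal
  # Bias correction: divide by 1 - decay**count.
  corrected.append(moment / (1.0 - decay ** count))
```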

optax.centralize()

Centralize gradients.

References

[Yong et al, 2020](https://arxiv.org/abs/2004.01461)

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.clip(max_delta)

Clips updates element-wise, to be in [-max_delta, +max_delta].

Parameters:

max_delta (Union[Array, ndarray, bool_, number, float, int]) – The maximum absolute value for each element in the update.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.clip_by_block_rms(threshold)

Clips updates to a max rms for the gradient of each param vector or matrix.

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

Parameters:

threshold (float) – The maximum rms for the gradient of each param vector or matrix.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.ClipState

alias of EmptyState

optax.clip_by_global_norm(max_norm)

Clips updates using their global norm.

References

[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

Parameters:

max_norm (float) – The maximum global norm for an update.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.ClipByGlobalNormState

alias of EmptyState

optax.ema(decay, debias=True, accumulator_dtype=None)

Compute an exponential moving average of past updates.

Note: trace and ema have very similar but distinct updates; ema = decay * ema + (1-decay) * t, while trace = decay * trace + t. Both are frequently found in the optimization literature.

Parameters:
  • decay (float) – Decay rate for the exponential moving average.

  • debias (bool) – Whether to debias the transformed gradient.

  • accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.EmaState(count: chex.Array, ema: base.Params)

Holds an exponential moving average of past updates.

count: chex.Array

Alias for field number 0

ema: base.Params

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

class optax.EmptyState

An empty state for the simplest stateless transformations.

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls)

Create new instance of EmptyState()

optax.global_norm(updates)

Compute the global norm across a nested structure of tensors.

Return type:

Union[Array, ndarray, bool_, number]

optax.identity()

Stateless identity transformation that leaves input gradients untouched.

This function passes through the gradient updates unchanged.

Note: this should not be confused with set_to_zero, which maps the input updates to zero; that is the transformation required for the model parameters to be left unchanged when the updates are applied to them.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.keep_params_nonnegative()

Modifies the updates to keep parameters non-negative, i.e. >= 0.

This transformation ensures that parameters after the update will be larger than or equal to zero. In a chain of transformations, this should be the last one.

WARNING: the transformation expects input params to be non-negative. When params are negative, the transformed update will move them to 0.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.NonNegativeParamsState

alias of EmptyState

optax.per_example_global_norm_clip(grads, l2_norm_clip)

Applies gradient clipping per-example using their global norm.

References

[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

Parameters:
  • grads (list[Union[Array, ndarray, bool_, number]]) – flattened update; the function expects these to have a batch dimension on the 0th axis.

  • l2_norm_clip (float) – maximum L2 norm of the per-example gradients.

Return type:

tuple[list[Union[Array, ndarray, bool_, number]], Array]

Returns:

A tuple containing the sum of the clipped per-example grads, and the number of per-example grads that were clipped.

optax.per_example_layer_norm_clip(grads, global_l2_norm_clip, uniform=True, eps=1e-08)

Applies gradient clipping per-example using per-layer norms.

References

[McMahan et al, 2017](https://arxiv.org/abs/1710.06963)

Parameters:
  • grads (list[Union[Array, ndarray, bool_, number]]) – flattened update; i.e. a list of gradients in which each item is the gradient for one layer; the function expects these to have a batch dimension on the 0th axis.

  • global_l2_norm_clip (float) – overall L2 clip norm to use.

  • uniform (bool) – If True, per-layer clip norm is global_l2_norm_clip/sqrt(L), where L is the number of layers. Otherwise, per-layer clip norm is global_l2_norm_clip * sqrt(f), where f is the fraction of total model parameters that are in this layer.

  • eps (float) – Small positive value to add to norms to avoid possible division by zero.

Let C = global_l2_norm_clip. Then per-layer clipping is done as follows:

  1. If uniform is True, each of the K layers has an individual clip norm of C / sqrt(K).

  2. If uniform is False, each of the K layers has an individual clip norm of C * sqrt(D_i / D), where D_i is the number of parameters in layer i and D is the total number of parameters in the model.
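To see why both rules are consistent with the global clip C, the hypothetical helper below (not an Optax function) computes the per-layer clip norms for each setting. In both cases the squared per-layer norms sum to C², so a gradient that satisfies every per-layer bound has a global L2 norm of at most C.

```python
import jax.numpy as jnp


def per_layer_clip_norms(layer_sizes, global_clip, uniform=True):
  """Hypothetical helper: per-layer clip norms following the two rules above."""
  sizes = jnp.asarray(layer_sizes, dtype=jnp.float32)
  num_layers = sizes.shape[0]
  if uniform:
    # Rule 1: every layer gets C / sqrt(K).
    return jnp.full((num_layers,), global_clip / jnp.sqrt(num_layers))
  # Rule 2: layer i gets C * sqrt(D_i / D).
  return global_clip * jnp.sqrt(sizes / sizes.sum())


uniform_norms = per_layer_clip_norms([30, 70], global_clip=1.0, uniform=True)
weighted_norms = per_layer_clip_norms([30, 70], global_clip=1.0, uniform=False)
```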

Return type:

tuple[list[Union[Array, ndarray, bool_, number]], list[Union[Array, ndarray, bool_, number]]]

Returns:

A tuple containing the sum of the clipped per-example grads and the number of per-example grads that were clipped for each layer.

optax.scale(step_size)

Scale updates by some fixed scalar step_size.

Parameters:

step_size (float) – A scalar corresponding to a fixed scaling factor for updates.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.ScaleState

alias of EmptyState

optax.scale_by_adadelta(rho=0.9, eps=1e-06)

Rescale updates according to the Adadelta algorithm.

References

[Matthew D. Zeiler, 2012](https://arxiv.org/pdf/1212.5701.pdf)

Parameters:
  • rho (float) – A coefficient used for computing a running average of squared gradients.

  • eps (float) – Term added to the denominator to improve numerical stability.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.ScaleByAdaDeltaState(e_g: base.Updates, e_x: base.Updates)

State for the rescaling by Adadelta algorithm.

e_g: base.Updates

Alias for field number 0

e_x: base.Updates

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, *, nesterov=False)

Rescale updates according to the Adam algorithm.

References

Kingma et al, Adam: A Method for Stochastic Optimization, 2014

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Warning

PyTorch and Optax’s Adam follow Algorithm 1 of Kingma and Ba’s Adam paper. If reproducing old results, note that TensorFlow instead used the formulation just before Section 2.1 of the paper. See deepmind/optax#571 for more detail.

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • mu_dtype (Union[str, type[Any], dtype, SupportsDType, None]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • nesterov (bool) – Whether to use Nesterov momentum. The variant of Adam with Nesterov momentum is described in [Dozat 2016].

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.scale_by_adamax(b1=0.9, b2=0.999, eps=1e-08)

Rescale updates according to the Adamax algorithm.

References

[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted maximum of grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.ScaleByAdamState(count: chex.Array, mu: base.Updates, nu: base.Updates)

State for the Adam algorithm.

count: chex.Array

Alias for field number 0

mu: base.Updates

Alias for field number 1

nu: base.Updates

Alias for field number 2

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_amsgrad(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)

Rescale updates according to the AMSGrad algorithm.

References

[Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ)

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • mu_dtype (Union[str, type[Any], dtype, SupportsDType, None]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.ScaleByAmsgradState(count: chex.Array, mu: base.Updates, nu: base.Updates, nu_max: base.Updates)

State for the AMSGrad algorithm.

count: chex.Array

Alias for field number 0

mu: base.Updates

Alias for field number 1

nu: base.Updates

Alias for field number 2

nu_max: base.Updates

Alias for field number 3

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_backtracking_linesearch(max_backtracking_steps, slope_rtol=0.0001, decrease_factor=0.8, increase_factor=1.5, max_learning_rate=1.0, atol=0.0, rtol=0.0, store_grad=False)

Backtracking line-search ensuring sufficient decrease (Armijo criterion).

Selects learning rate \(\gamma\) such that it verifies the decrease condition

\[f(w + \gamma u) \leq (1+\delta)f(w) + \gamma c \langle u, \nabla f(w) \rangle + \epsilon \,, \]

where \(f\) is the function to minimize, \(\gamma\) is the learning rate to find, \(u\) is the update direction, \(c\) is a coefficient (slope_rtol) measuring the relative decrease of the function in terms of the slope (scalar product between the gradient and the updates), \(\delta\) is a relative tolerance (rtol), and \(\epsilon\) is an absolute tolerance (atol).

The algorithm starts with a given guess of the learning rate and decreases it by decrease_factor until the criterion above is met.

Warning

The sufficient decrease condition might be impossible to satisfy for some update directions. To guarantee a non-trivial solution for the sufficient decrease condition, employ a descent direction for updates (\(u\)). An update (\(u\)) is considered a descent direction if the derivative of \(f(w + \gamma u)\) at \(\gamma = 0\) (i.e., \(\langle u, \nabla f(w)\rangle\)) is negative. This condition is automatically satisfied when using optax.sgd() (without momentum), but may not hold true for other optimizers like optax.adam().

More generally, when chained with other transforms as optax.chain(opt_1, ..., opt_k, scale_by_backtracking_linesearch(max_backtracking_steps=...), opt_kplusone, ..., opt_n), the updates returned by chaining opt_1, ..., opt_k must be a descent direction. However, any transform after the backtracking line-search doesn’t necessarily need to satisfy the descent direction property (one could for example use momentum).

See also

optax.value_and_grad_from_state() to make this method more efficient for non-stochastic objectives.

New in version 0.2.0.

Examples

An example on using the backtracking line-search with SGD:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> solver = optax.chain(
...    optax.sgd(learning_rate=1.),
...    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15)
... )
>>> # Function with additional inputs other than params
>>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y)
>>> params = jnp.array([1., 2., 3.])
>>> opt_state = solver.init(params)
>>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.)
>>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,))
>>> print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 5.00E+01
>>> for x, y in zip(xs, ys):
...   value, grad = jax.value_and_grad(fn)(params, x, y)
...   updates, opt_state = solver.update(
...       grad,
...       opt_state,
...       params,
...       value=value,
...       grad=grad,
...       value_fn=fn,
...       x=x,
...       y=y
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 3.86E+01
Objective function: 2.50E+01
Objective function: 1.34E+01
Objective function: 5.87E+00
Objective function: 5.81E+00

A similar example, but with a non-stochastic function where we can reuse the value and the gradient computed at the end of the linesearch:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> # Function without extra arguments
>>> def fn(params): return jnp.sum(params ** 2)
>>> params = jnp.array([1., 2., 3.])
>>> # In this case we can store value and grad with the store_grad field
>>> # and reuse them using optax.value_and_grad_from_state
>>> solver = optax.chain(
...    optax.sgd(learning_rate=1.),
...    optax.scale_by_backtracking_linesearch(
...        max_backtracking_steps=15, store_grad=True
...    )
... )
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02

References

Vaswani et al., Painless Stochastic Gradient, 2019

Nocedal & Wright, Numerical Optimization, 1999

Parameters:
  • max_backtracking_steps (int) – maximum number of iterations for the line-search.

  • slope_rtol (float) – relative tolerance w.r.t. the slope. The sufficient decrease must be slope_rtol * lr * <grad, updates>; see the formula above.

  • decrease_factor (float) – factor by which the learning rate is reduced.

  • increase_factor (float) – factor by which the learning rate guess is increased. Setting it to 1. keeps the current guess; setting it to math.inf amounts to starting with max_learning_rate at each round.

  • max_learning_rate (float) – maximum learning rate (learning rate guess clipped to this).

  • atol (float) – absolute tolerance at which the condition needs to be satisfied.

  • rtol (float) – relative tolerance at which the condition needs to be satisfied.

  • store_grad (bool) – whether to compute and store the gradient at the end of the linesearch. Since the function is called to compute the value to accept the learning rate, we can also access the gradient along the way. By doing that, we can directly reuse the value and the gradient computed at the end of the linesearch for the next iteration using optax.value_and_grad_from_state(). See the example above.

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformationExtraArgs, where the update function takes the following additional keyword arguments:

  • value: value of the function at the current params.

  • grad: gradient of the function at the current params.

  • value_fn: function returning the value of the function we seek to optimize.

  • **extra_args: additional keyword arguments. If the function needs additional arguments, such as input data, pass them here (see the example in this docstring).
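The acceptance test controlled by these parameters can be sketched as a plain-Python backtracking loop. This is a minimal illustration, not optax's implementation: it assumes vectors stored as lists, checks only the sufficient-decrease condition with slope_rtol * lr * <grad, updates> (atol and rtol are omitted), and the helper backtracking_step and its default values are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def backtracking_step(
    fn: Callable[[Vector], float],
    params: Vector,
    grad: Vector,
    updates: Vector,
    prev_learning_rate: float,
    slope_rtol: float = 1e-4,
    decrease_factor: float = 0.8,
    increase_factor: float = 1.5,
    max_learning_rate: float = 1.0,
    max_backtracking_steps: int = 15,
) -> float:
    """Hypothetical helper: returns the first learning rate that passes the
    sufficient-decrease test. Default values are illustrative only."""
    value = fn(params)
    slope = dot(grad, updates)  # negative when `updates` is a descent direction
    # Start from an enlarged guess, clipped to max_learning_rate.
    lr = min(prev_learning_rate * increase_factor, max_learning_rate)
    for _ in range(max_backtracking_steps):
        candidate = [p + lr * u for p, u in zip(params, updates)]
        # Accept if f(x + lr * u) <= f(x) + slope_rtol * lr * <grad, u>.
        if fn(candidate) <= value + slope_rtol * lr * slope:
            return lr
        lr *= decrease_factor
    return lr  # budget exhausted: return the last decreased guess (not re-checked)


def sum_of_squares(x: Vector) -> float:
    return sum(v * v for v in x)


x = [1.0, 2.0]
g = [2.0 * v for v in x]  # gradient of sum_of_squares at x
u = [-v for v in g]       # steepest-descent direction
lr = backtracking_step(sum_of_squares, x, g, u, prev_learning_rate=1.0)
print(lr)  # 0.8: lr = 1.0 lands on f = 5.0 (no decrease), lr = 0.8 gives f = 1.8
```

Starting each round from the previous learning rate times increase_factor lets the step grow again after a few small ones, while the clip to max_learning_rate bounds that growth.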

class optax.ScaleByBacktrackingLinesearchState(learning_rate: float | jax.Array, value: float | jax.Array, grad: base.Updates | None = None)

State for optax.scale_by_backtracking_linesearch().

learning_rate

learning rate computed at the end of a round of line-search, used to scale the update.

value

value of the objective computed at the end of a round of line-search. Can be reused using optax.value_and_grad_from_state().

grad

gradient of the objective computed at the end of a round of line-search if the line-search is instantiated with store_grad = True. Otherwise it is None. Can be reused using optax.value_and_grad_from_state().

optax.scale_by_belief(b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)

Rescale updates according to the AdaBelief algorithm.

References

[Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of variance of grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
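To make the roles of b1, b2, eps and eps_root concrete, here is a scalar plain-Python sketch of one AdaBelief step after Zhuang et al. (2020). It assumes Adam-style bias correction and adds eps_root directly to the second-moment estimate; adabelief_step is a hypothetical name and this is not optax's code.

```python
import math


def adabelief_step(g, mu, s, count, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16):
    """Hypothetical scalar AdaBelief step; returns (update, mu, s, count)."""
    count += 1
    mu = b1 * mu + (1 - b1) * g
    # Second moment of the prediction error g - mu (Adam would use g ** 2).
    s = b2 * s + (1 - b2) * (g - mu) ** 2 + eps_root
    mu_hat = mu / (1 - b1 ** count)  # bias correction, assumed Adam-style
    s_hat = s / (1 - b2 ** count)
    return mu_hat / (math.sqrt(s_hat) + eps), mu, s, count


mu, s, count = 0.0, 0.0, 0
steps = []
for _ in range(3):
    update, mu, s, count = adabelief_step(1.0, mu, s, count)
    steps.append(update)
```

With a constant gradient the prediction error g - mu shrinks, so the step grows over these three iterations; an Adam-style g ** 2 statistic would instead keep the bias-corrected step near 1.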

class optax.ScaleByBeliefState(count: chex.Array, mu: base.Updates, nu: base.Updates)

State for rescaling by the AdaBelief algorithm.

count: chex.Array

Alias for field number 0

mu: base.Updates

Alias for field number 1

nu: base.Updates

Alias for field number 2

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_factored_rms(factored=True, decay_rate=0.8, step_offset=0, min_dim_size_to_factor=128, epsilon=1e-30, decay_rate_fn=<function _decay_rate_pow>)

Scaling by a factored estimate of the gradient rms (as in Adafactor).

This is a so-called “1+epsilon” scaling algorithm that is extremely memory-efficient compared to RMSProp/Adam, and it has been widely successful in large-scale training of attention-based models.

References

[Shazeer et al, 2018](https://arxiv.org/abs/1804.04235)

Parameters:
  • factored (bool) – Whether to use factored second-moment estimates.

  • decay_rate (float) – Controls the second-moment exponential decay schedule.

  • step_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.

  • min_dim_size_to_factor (int) – Only factor the accumulator if two array dimensions are at least this size.

  • epsilon (float) – Regularization constant for squared gradient.

  • decay_rate_fn (Callable[[int, float], Union[Array, ndarray, bool_, number]]) – A function that takes the current step and the decay rate parameter and controls the schedule for the second moment. Defaults to the original Adafactor power decay schedule. One potential shortcoming of the original schedule is that the second-moment decay rate converges to 1, which effectively freezes the second-moment estimate. To prevent this, the user can opt for a custom schedule that sets an upper bound for the decay rate, like in [Zhai et al., 2021](https://arxiv.org/abs/2106.04560).

Returns:

the corresponding GradientTransformation.
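The decay_rate_fn parameter can be illustrated with two plain-Python schedules. The first assumes the default follows the power law 1 - (step + 1) ** (-decay_rate); the second is a hypothetical variant that caps the decay rate, in the spirit of Zhai et al. (2021), with an arbitrary cap of 0.999. Inside optax the function would receive JAX values, so a real schedule would use jnp.minimum rather than Python's min.

```python
def power_decay_rate(step: int, decay_rate: float) -> float:
    """Adafactor-style power schedule (assumed form of the default)."""
    return 1.0 - (step + 1.0) ** (-decay_rate)


def capped_decay_rate(step: int, decay_rate: float, cap: float = 0.999) -> float:
    """Hypothetical variant that bounds the second-moment decay from above."""
    return min(power_decay_rate(step, decay_rate), cap)


for step in (0, 10, 1_000, 1_000_000):
    print(step, round(power_decay_rate(step, 0.8), 6), capped_decay_rate(step, 0.8))
```

The uncapped rate keeps creeping toward 1 as training proceeds, while the capped one stops at 0.999, so the second-moment estimate keeps adapting to recent gradients.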

class optax.FactoredState(count: chex.Array, v_row: chex.ArrayTree, v_col: chex.ArrayTree, v: chex.ArrayTree)

Overall state of the gradient transformation.

count: chex.Array

Alias for field number 0

v_row: chex.ArrayTree

Alias for field number 1

v_col: chex.ArrayTree

Alias for field number 2

v: chex.ArrayTree

Alias for field number 3

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_learning_rate(learning_rate, *, flip_sign=True)

Scale by the (negative) learning rate (either as scalar or as schedule).

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Can either be a scalar or a schedule (i.e. a callable that maps an (int) step to a float).

  • flip_sign (bool) – When set to True (the default) this corresponds to scaling by the negative learning rate.

Return type:

GradientTransformation

Returns:

An optax.GradientTransformation that corresponds to multiplying the gradient with -learning_rate (if flip_sign is True) or with learning_rate (if flip_sign is False).
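The sign convention and the scalar-versus-schedule behaviour can be mocked in plain Python. This is an illustration under assumptions, not optax's code: the hypothetical make_scale_by_learning_rate returns an (init, update) pair, the state is just the number of updates applied so far, and a schedule is evaluated at that count (starting from 0) before the count is incremented.

```python
from typing import Callable, List, Tuple, Union

Schedule = Callable[[int], float]


def make_scale_by_learning_rate(learning_rate: Union[float, Schedule],
                                flip_sign: bool = True):
    """Hypothetical mock returning (init, update) functions."""
    sign = -1.0 if flip_sign else 1.0

    def init(params: List[float]) -> int:
        return 0  # state: number of updates applied so far

    def update(updates: List[float], count: int) -> Tuple[List[float], int]:
        lr = learning_rate(count) if callable(learning_rate) else learning_rate
        return [sign * lr * u for u in updates], count + 1

    return init, update


def halve_each_step(count: int) -> float:  # hypothetical schedule
    return 0.1 * 0.5 ** count


init, update = make_scale_by_learning_rate(halve_each_step)
state = init([0.0, 0.0])
step1, state = update([1.0, -2.0], state)  # scaled by -0.1
step2, state = update([1.0, -2.0], state)  # scaled by -0.05
```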

optax.scale_by_lion(b1=0.9, b2=0.99, mu_dtype=None)

Rescale updates according to the Lion algorithm.

References

[Chen et al, 2023](https://arxiv.org/abs/2302.06675)

Parameters:
  • b1 (float) – Rate for combining the momentum and the current grad.

  • b2 (float) – Decay rate for the exponentially weighted average of grads.

  • mu_dtype (Union[str, type[Any], dtype, SupportsDType, None]) – Optional dtype to be used for the momentum; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
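Following Chen et al. (2023), b1 and b2 play different roles: b1 mixes momentum and gradient only to choose the sign of the update, while b2 controls how the stored momentum itself is updated. A scalar plain-Python sketch under that reading (lion_step is a hypothetical name; mu_dtype is not modelled):

```python
def sign(x: float) -> float:
    return float((x > 0) - (x < 0))


def lion_step(g: float, mu: float, b1: float = 0.9, b2: float = 0.99):
    """Hypothetical scalar Lion step; returns (update, new_mu)."""
    # b1 interpolates momentum and gradient, but only the sign is kept, so
    # each coordinate moves by -1, 0 or +1 before the learning rate applies.
    update = sign(b1 * mu + (1 - b1) * g)
    # The stored momentum is updated separately, with decay rate b2.
    new_mu = b2 * mu + (1 - b2) * g
    return update, new_mu


# A gradient that disagrees with the momentum flips the step only if it
# dominates the b1-weighted mix: 0.9 * (-1.0) + 0.1 * 5.0 is still negative.
update, new_mu = lion_step(5.0, -1.0)
```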

class optax.ScaleByLionState(count: chex.Array, mu: base.Updates)

State for the Lion algorithm.

count: chex.Array

Alias for field number 0

mu: base.Updates

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_novograd(b1=0.9, b2=0.25, eps=1e-08, eps_root=0.0, weight_decay=0.0, mu_dtype=None)

Computes NovoGrad updates.

References

[Ginsburg et al, 2019](https://arxiv.org/abs/1905.11286)

Parameters:
  • b1 (float) – A decay rate for the exponentially weighted average of grads.

  • b2 (float) – A decay rate for the exponentially weighted average of squared grads.

  • eps (float) – A term added to the denominator to improve numerical stability.

  • eps_root (float) – A term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • weight_decay (float) – A scalar weight decay rate.

  • mu_dtype (Union[str, type[Any], dtype, SupportsDType, None]) – An optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.
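NovoGrad keeps one second-moment scalar per layer (pytree leaf) instead of one per parameter. The sketch below follows the update rule in Ginsburg et al. (2019) for a single layer stored as a list, omitting eps_root and mu_dtype; it is an illustration under that assumption rather than optax's implementation, and novograd_layer_step is a hypothetical name.

```python
import math
from typing import List, Optional, Tuple


def novograd_layer_step(
    grads: List[float],
    params: List[float],
    mu: Optional[List[float]],
    nu: Optional[float],
    b1: float = 0.9,
    b2: float = 0.25,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
) -> Tuple[List[float], List[float], float]:
    """Hypothetical single-layer step; mu and nu are None before the first step.

    Returns (update, mu, nu).
    """
    sq_norm = sum(g * g for g in grads)  # squared L2 norm of the layer gradient
    # One scalar second moment per layer, initialized to the first squared norm.
    nu = sq_norm if nu is None else b2 * nu + (1 - b2) * sq_norm
    denom = math.sqrt(nu) + eps
    normalized = [g / denom + weight_decay * p for g, p in zip(grads, params)]
    # The first moment accumulates normalized gradients (no (1 - b1) factor).
    if mu is None:
        mu = normalized
    else:
        mu = [b1 * m + n for m, n in zip(mu, normalized)]
    return list(mu), mu, nu


layer_grads, layer_params = [3.0, 4.0], [0.0, 0.0]
update, mu, nu = novograd_layer_step(layer_grads, layer_params, None, None)
update, mu, nu = novograd_layer_step(layer_grads, layer_params, mu, nu)
```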

class optax.ScaleByNovogradState(count: chex.Array, mu: base.Updates, nu: base.Updates)

State for Novograd.

count: chex.Array

Alias for field number 0

mu: base.Updates

Alias for field number 1

nu: base.Updates

Alias for field number 2

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_optimistic_gradient(alpha=1.0, beta=1.0)

Compute generalized optimistic gradients.

References

[Mokhtari et al, 2019](https://arxiv.org/abs/1901.08511v2)

Parameters:
  • alpha (float) – Coefficient for generalized optimistic gradient descent.

  • beta (float) – Coefficient for negative momentum.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
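The two coefficients weight the current and the previous gradient. A scalar plain-Python sketch, assuming the generalized optimistic update (alpha + beta) * g_t - beta * g_{t-1} with the previous gradient initialized to zero; with alpha = beta = 1 this reduces to the familiar 2 * g_t - g_{t-1} of optimistic gradient descent. optimistic_updates is a hypothetical helper.

```python
from typing import Iterable, List


def optimistic_updates(grads: Iterable[float], alpha: float = 1.0,
                       beta: float = 1.0) -> List[float]:
    """Applies the generalized optimistic rule to a sequence of scalar gradients."""
    prev = 0.0  # previous gradient, assumed to start at zero
    out = []
    for g in grads:
        # alpha scales the plain gradient; beta adds g_t - g_{t-1} ("negative momentum").
        out.append((alpha + beta) * g - beta * prev)
        prev = g
    return out


print(optimistic_updates([1.0, 3.0, 2.0]))
```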

optax.scale_by_param_block_norm(min_scale=0.001)

Scale updates for each param block by the norm of that block’s parameters.

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/params pytree.

Parameters:

min_scale (float) – Minimum scaling factor.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
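Each update leaf is multiplied by the norm of the matching parameter leaf. A plain-Python sketch with pytrees modelled as dicts of lists; it assumes the L2 norm of the whole leaf, floored at min_scale, which is one reading of "minimum scaling factor". scale_by_block_norm_sketch is hypothetical.

```python
import math
from typing import Dict, List

Tree = Dict[str, List[float]]


def scale_by_block_norm_sketch(updates: Tree, params: Tree, min_scale: float = 1e-3) -> Tree:
    """Hypothetical helper: scales each update leaf by its parameter block's norm."""
    scaled = {}
    for name, leaf in updates.items():
        # L2 norm of the whole parameter block, floored at min_scale.
        norm = max(math.sqrt(sum(p * p for p in params[name])), min_scale)
        scaled[name] = [u * norm for u in leaf]
    return scaled


params = {"w": [3.0, 4.0], "b": [0.0]}
updates = {"w": [0.1, -0.1], "b": [0.5]}
scaled = scale_by_block_norm_sketch(updates, params)  # w scaled by 5.0, b by 1e-3
```

The floor matters for blocks whose parameters start at zero, such as a zero-initialized bias: without it their updates would be scaled to zero and they would never move.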

optax.scale_by_param_block_rms(min_scale=0.001)

Scale updates for each param block by the RMS of that block’s parameters.

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/params pytree.

Parameters:

min_scale (float) – Minimum scaling factor.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.scale_by_radam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0)

Rescale updates according to the Rectified Adam algorithm.

References

[Liu et al, 2020](https://arxiv.org/abs/1908.03265)

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • threshold (float) – Threshold for variance tractability.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
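The threshold parameter decides when the variance of the adaptive learning rate counts as tractable. The plain-Python sketch below computes two quantities from Liu et al. (2020): rho_t, the length of the approximated simple moving average at step t, and the rectification factor applied once rho_t exceeds threshold. The function names are hypothetical, eps and eps_root are not modelled, and this is not optax's code.

```python
import math
from typing import Optional


def radam_rho(step: int, b2: float = 0.999) -> float:
    """Length of the approximated simple moving average at 1-based `step`."""
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    return rho_inf - 2.0 * step * b2 ** step / (1.0 - b2 ** step)


def radam_rectification(step: int, b2: float = 0.999,
                        threshold: float = 5.0) -> Optional[float]:
    """Rectification factor r_t, or None when rho_t <= threshold.

    None stands for the early phase in which the variance is treated as
    intractable and the bias-corrected momentum is used without the
    adaptive denominator.
    """
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = radam_rho(step, b2)
    if rho_t <= threshold:
        return None
    return math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )


for step in (1, 3, 10, 1000):
    print(step, round(radam_rho(step), 2), radam_rectification(step))
```

For b2 = 0.999, rho_t stays close to t during the first steps, so with threshold = 5.0 roughly the first five updates skip the adaptive denominator.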

optax.scale_by_polyak(f_min=0.0, max_learning_rate=1.0, eps=0.0)

Scales the update by Polyak’s step-size.

Return type:

GradientTransformationExtraArgs
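This entry gives no formula, so the sketch below assumes the textbook Polyak step size (f(x) - f_min) / ||g||^2, with eps added to the denominator and the result capped at max_learning_rate; exactly how optax combines these parameters is an assumption here. Because the step needs the current objective value f(x), that value presumably arrives as an extra argument, which fits the GradientTransformationExtraArgs return type. polyak_step_size is hypothetical.

```python
from typing import List


def polyak_step_size(value: float, grad: List[float], f_min: float = 0.0,
                     max_learning_rate: float = 1.0, eps: float = 0.0) -> float:
    """Hypothetical helper: textbook Polyak step, capped at max_learning_rate."""
    denom = sum(g * g for g in grad) + eps  # squared gradient norm (+ eps)
    if denom == 0.0:
        return 0.0  # assumption: no step when the gradient vanishes
    return min((value - f_min) / denom, max_learning_rate)


# f(x) = 0.5 * ||x||^2 has minimum value 0, so f_min = 0.0 is exact here.
x = [2.0, -1.0]
value = 0.5 * sum(v * v for v in x)   # 2.5
grad = list(x)                        # the gradient of f is x itself
step = polyak_step_size(value, grad)  # 2.5 / 5.0 = 0.5
update = [step * g for g in grad]     # descend with x - update
```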

optax.scale_by_rms(decay=0.9, eps=1e-08, initial_scale=0.0)

Rescale updates by the root of the exponential moving average of the squared updates.

WARNING: PyTorch and optax’s RMSprop implementations differ, which could impact performance. In the denominator, optax uses $\sqrt{v + \epsilon}$, whereas PyTorch uses $\sqrt{v} + \epsilon$. See google-deepmind/optax#532 for more detail.

References

[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

Parameters:
  • decay (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • initial_scale (float) – Initial value for second moment.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
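The implementation difference in the warning is easiest to see with a tiny gradient and a freshly initialized second moment (initial_scale=0.0). The scalar sketch assumes the update v = decay * v + (1 - decay) * g**2 and the two denominators as stated in the warning; rms_step and its torch_style flag are hypothetical.

```python
import math


def rms_step(g: float, v: float, decay: float = 0.9, eps: float = 1e-8,
             torch_style: bool = False):
    """Hypothetical scalar RMSprop-style step; returns (update, new_v)."""
    v = decay * v + (1 - decay) * g * g
    # optax-style: sqrt(v + eps); PyTorch-style: sqrt(v) + eps.
    denom = math.sqrt(v) + eps if torch_style else math.sqrt(v + eps)
    return g / denom, v


g = 1e-6  # tiny gradient, v starts at 0.0
optax_style, _ = rms_step(g, 0.0)
torch_like, _ = rms_step(g, 0.0, torch_style=True)
print(optax_style, torch_like)  # roughly 0.01 versus roughly 3.07
```

With eps inside the root, sqrt(eps) = 1e-4 dominates the denominator here, whereas eps = 1e-8 added outside the root barely changes sqrt(v), so the two conventions give steps that differ by more than two orders of magnitude.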

class optax.ScaleByRmsState(nu: base.Updates)

State for exponential root mean-squared (RMS)-normalized updates.

nu: base.Updates

Alias for field number 0

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-06, max_step_size=50.0)

Scale with the Rprop optimizer.

Rprop, short for resilient backpropagation, is a first-order variant of gradient descent. It responds only to the sign of the gradient, increasing or decreasing the step size selected per parameter exponentially to speed up convergence and avoid oscillations.

References

PyTorch implementation: https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html

Riedmiller and Braun, 1993: https://ieeexplore.ieee.org/document/298623

Igel and Hüsken, 2003:

Parameters:
  • learning_rate (float) – The initial step size.

  • eta_minus (float) – Multiplicative factor for decreasing step size. This is applied when the gradient changes sign from one step to the next.

  • eta_plus (float) – Multiplicative factor for increasing step size. This is applied when the gradient has the same sign from one step to the next.

  • min_step_size (float) – Minimum allowed step size. Smaller steps will be clipped to this value.

  • max_step_size (float) – Maximum allowed step size. Larger steps will be clipped to this value.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.
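The step-size adaptation described above can be sketched for a single scalar parameter. This follows the iRprop- variant (on a sign change the step shrinks and the gradient is treated as zero for that step), which is an assumption about the exact variant; the sketch returns a descent step, -sign(g) * step_size, and the sign convention of optax's own output may differ. rprop_step is a hypothetical name.

```python
def rprop_step(g: float, prev_g: float, step_size: float,
               eta_minus: float = 0.5, eta_plus: float = 1.2,
               min_step_size: float = 1e-6, max_step_size: float = 50.0):
    """Hypothetical scalar iRprop- step; returns (update, new_prev_g, new_step_size)."""
    product = g * prev_g
    if product > 0:    # same sign as last step: grow the step
        step_size = min(step_size * eta_plus, max_step_size)
    elif product < 0:  # sign flipped, so we overshot: shrink the step
        step_size = max(step_size * eta_minus, min_step_size)
        g = 0.0        # skip this update and forget the gradient
    sign = (g > 0) - (g < 0)
    return -sign * step_size, g, step_size


step_size, prev_g = 0.1, 0.0  # learning_rate=0.1 acts as the initial step size
history = []
for g in [2.0, 3.0, -1.0, -1.0]:
    update, prev_g, step_size = rprop_step(g, prev_g, step_size)
    history.append(update)
# history is approximately [-0.1, -0.12, 0.0, 0.06]
```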

class optax.ScaleByRpropState(step_sizes, prev_updates)

step_sizes: base.Updates

Alias for field number 0

prev_updates: base.Updates

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_rss(initial_accumulator_value=0.1, eps=1e-07)

Rescale updates by the root of the sum of all squared gradients to date.

References

[Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)

Parameters:
  • initial_accumulator_value (float) – Starting value for accumulators, must be >= 0.

  • eps (float) – A small floating-point value to avoid a zero denominator.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
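Because the accumulator only grows, the effective step size for each coordinate shrinks over time. A scalar plain-Python sketch in the style of Adagrad (Duchi et al., 2011); placing eps inside the square root is an assumption, and rss_updates is hypothetical.

```python
import math
from typing import List


def rss_updates(grads: List[float], initial_accumulator_value: float = 0.1,
                eps: float = 1e-7) -> List[float]:
    """Hypothetical helper: rescales a sequence of scalar gradients."""
    sum_of_squares = initial_accumulator_value
    out = []
    for g in grads:
        sum_of_squares += g * g  # running sum of all squared gradients to date
        # eps guards the root; a zero accumulator yields a zero update (assumed).
        out.append(g / math.sqrt(sum_of_squares + eps) if sum_of_squares > 0 else 0.0)
    return out


steps = rss_updates([1.0, 1.0, 1.0])  # same gradient, shrinking steps
```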

class optax.ScaleByRssState(sum_of_squares: base.Updates)

State holding the sum of gradient squares to date.

sum_of_squares: base.Updates

Alias for field number 0

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_schedule(step_size_fn)

Scale updates using a custom schedule for the step_size.

Parameters:

step_size_fn (Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]) – A function that takes an update count as input and proposes the step_size to multiply the updates by.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.ScaleByScheduleState(count: chex.Array)

Maintains count for scale scheduling.

count: Union[Array, ndarray, bool_, number]

Alias for field number 0

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_sm3(b1=0.9, b2=1.0, eps=1e-08)

Scale updates by the SM3 algorithm.

References

[Anil et al, 2019](https://arxiv.org/abs/1901.11150)

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

class optax.ScaleBySM3State(mu: base.Updates, nu: base.Updates)

State for the SM3 algorithm.

mu: base.Updates

Alias for field number 0

nu: base.Updates

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_stddev(decay=0.9, eps=1e-08, initial_scale=0.0)

Rescale updates by the root of the centered exp. moving average of squares.

References

[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

Parameters:
  • decay (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • initial_scale (float) – Initial value for second moment.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
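Compared with scale_by_rms, the centered variant subtracts the squared running mean, so the denominator estimates a standard deviation rather than a root mean square. A scalar plain-Python sketch with initial_scale = 0 and eps inside the square root (both assumptions); the function names are hypothetical.

```python
import math


def stddev_step(g, mu, nu, decay=0.9, eps=1e-8):
    """Hypothetical scalar step; returns (update, mu, nu)."""
    mu = decay * mu + (1 - decay) * g      # EMA of gradients
    nu = decay * nu + (1 - decay) * g * g  # EMA of squared gradients
    # nu - mu**2 estimates the gradient variance; eps keeps the root positive.
    return g / math.sqrt(nu - mu * mu + eps), mu, nu


def rms_only_step(g, nu, decay=0.9, eps=1e-8):
    """Uncentered counterpart, for comparison."""
    nu = decay * nu + (1 - decay) * g * g
    return g / math.sqrt(nu + eps), nu


centered, _, _ = stddev_step(1.0, 0.0, 0.0)
uncentered, _ = rms_only_step(1.0, 0.0)
```

Subtracting mu**2 can only shrink the denominator, so for the same gradient history the centered update is at least as large as the uncentered one.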

class optax.ScaleByRStdDevState(mu: base.Updates, nu: base.Updates)

State for centered exponential moving average of squares of updates.

mu: base.Updates

Alias for field number 0

nu: base.Updates

Alias for field number 1

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.scale_by_trust_ratio(min_norm=0.0, trust_coefficient=1.0, eps=0.0)

Scale updates by trust ratio.

References

[You et al, 2020](https://arxiv.org/abs/1904.00962)

Parameters:
  • min_norm (float) – Minimum norm for the param and gradient norms; zero by default.

  • trust_coefficient (float) – A multiplier for the trust ratio.

  • eps (float) – Additive constant added to the denominator for numerical stability.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
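The trust ratio, as used in LAMB (You et al., 2020), rescales each update leaf so that its norm is proportional to the norm of the corresponding parameters. A plain-Python sketch for one leaf; it assumes both norms are floored at min_norm and that the update is left unscaled when either norm is zero. trust_ratio_scale is hypothetical.

```python
import math
from typing import List


def l2(x: List[float]) -> float:
    return math.sqrt(sum(v * v for v in x))


def trust_ratio_scale(update: List[float], param: List[float], min_norm: float = 0.0,
                      trust_coefficient: float = 1.0, eps: float = 0.0) -> List[float]:
    """Hypothetical helper: rescales one update leaf by its trust ratio."""
    param_norm = max(l2(param), min_norm)
    update_norm = max(l2(update), min_norm)
    if param_norm == 0.0 or update_norm == 0.0:
        ratio = 1.0  # assumption: leave the update unscaled when a norm is zero
    else:
        ratio = trust_coefficient * param_norm / (update_norm + eps)
    return [u * ratio for u in update]


# ||param|| = 5 and ||update|| = 0.5, so the update is scaled by 10.
scaled = trust_ratio_scale([0.3, 0.4], [3.0, 4.0])
```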

class optax.ScaleByTrustRatioState

The scale and decay trust ratio transformation is stateless.

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls)

Create new instance of ScaleByTrustRatioState()

optax.scale_by_yogi(b1=0.9, b2=0.999, eps=0.001, eps_root=0.0, initial_accumulator_value=1e-06)

Rescale updates according to the Yogi algorithm.

Supports complex numbers, see https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29

References

[Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html)

Parameters:
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of variance of grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • initial_accumulator_value (float) – The starting value for accumulators. Only positive values are allowed.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
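Yogi differs from Adam in the second-moment update. Adam moves v toward g**2 by a fraction (1 - b2) of their difference, while Yogi changes v by (1 - b2) * g**2 in the direction given by the sign of v - g**2. A scalar plain-Python sketch of both rules after Zaheer et al. (2018); the names are hypothetical, and the first moment, bias correction and eps terms are omitted.

```python
def sign(x: float) -> int:
    return (x > 0) - (x < 0)


def adam_second_moment(v: float, g: float, b2: float = 0.999) -> float:
    # Equivalent to v - (1 - b2) * (v - g**2): the change scales with v - g**2.
    return b2 * v + (1 - b2) * g * g


def yogi_second_moment(v: float, g: float, b2: float = 0.999) -> float:
    g2 = g * g
    # Only the sign of v - g**2 is used, so the change is (1 - b2) * g**2.
    return v - (1 - b2) * sign(v - g2) * g2


# When v is much larger than g**2, Adam shrinks v faster than Yogi does.
v_adam = adam_second_moment(1.0, 0.1)
v_yogi = yogi_second_moment(1.0, 0.1)
```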

optax.set_to_zero()

Stateless transformation that maps input gradients to zero.

The resulting update function, when called, will return a tree of zeros matching the shape of the input gradients. This means that when the updates returned from this transformation are applied to the model parameters, the model parameters will remain unchanged.

This can be used in combination with multi_transform or masked to freeze (i.e. keep fixed) some parts of the tree of model parameters while applying gradient updates to other parts of the tree.

When updates are set to zero inside the same jit-compiled function as the calculation of gradients, optax transformations, and application of updates to parameters, unnecessary computations will in general be dropped.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.

optax.stateless(f)

Creates a stateless transformation from an update-like function.

This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations.

Parameters:

f (Callable[[Updates, Optional[Params]], Updates]) – Update function that takes in updates (e.g. gradients) and parameters and returns updates. The parameters may be None.

Return type:

GradientTransformation

Returns:

An optax.GradientTransformation.

optax.stateless_with_tree_map(f)

Creates a stateless transformation from an update-like function for arrays.

This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations, just like optax.stateless. In addition, it applies f leaf-wise, mapping it over the updates and params trees for you.

Parameters:

f (Callable[[Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number, None]], Union[Array, ndarray, bool_, number]]) – Update function that takes in an update array (e.g. gradients) and parameter array and returns an update array. The parameter array may be None.

Return type:

GradientTransformation

Returns:

An optax.GradientTransformation.
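Both wrappers remove the init boilerplate by pairing an update-like function with an empty state. The plain-Python mock below illustrates that pattern and is not optax's code: GradientTransformationSketch, stateless_sketch and stateless_with_tree_map_sketch are hypothetical, and pytrees are modelled as flat dicts.

```python
from typing import Callable, Dict, NamedTuple, Optional, Tuple

Tree = Dict[str, float]


class GradientTransformationSketch(NamedTuple):
    """Hypothetical stand-in for an (init, update) pair."""
    init: Callable
    update: Callable


def stateless_sketch(f: Callable[[Tree, Optional[Tree]], Tree]) -> GradientTransformationSketch:
    def init(params: Tree) -> Tuple[()]:
        return ()  # nothing to remember between iterations

    def update(updates: Tree, state: Tuple[()], params: Optional[Tree] = None):
        return f(updates, params), state

    return GradientTransformationSketch(init, update)


def stateless_with_tree_map_sketch(
    f: Callable[[float, Optional[float]], float]
) -> GradientTransformationSketch:
    # Same pattern, but f is applied to one leaf at a time.
    def tree_f(updates: Tree, params: Optional[Tree]) -> Tree:
        return {k: f(u, None if params is None else params[k]) for k, u in updates.items()}

    return stateless_sketch(tree_f)


# Example: add a hypothetical 0.01 * param weight-decay term to every leaf.
tx = stateless_with_tree_map_sketch(lambda g, p: g if p is None else g + 0.01 * p)
params = {"w": 2.0, "b": -1.0}
state = tx.init(params)
updates, state = tx.update({"w": 0.5, "b": 0.5}, state, params)
```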

optax.trace(decay, nesterov=False, accumulator_dtype=None)

Compute a trace of past updates.

Note: trace and ema have very similar but distinct updates; trace = decay * trace + t, while ema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.

Parameters:
  • decay (float) – Decay rate for the trace of past updates.

  • nesterov (bool) – Whether to use Nesterov momentum.

  • accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
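The note above can be checked numerically: with a constant update of 1.0 and decay=0.9, the trace approaches 1 / (1 - decay) = 10, whereas the undebiased EMA approaches 1. A scalar plain-Python sketch; the Nesterov branch assumes the look-ahead form g + decay * new_trace, and the function names are hypothetical.

```python
def trace_step(g: float, trace: float, decay: float, nesterov: bool = False):
    """Hypothetical scalar step; returns (update, new_trace)."""
    new_trace = g + decay * trace  # trace = decay * trace + t
    # Nesterov variant (assumed form): apply the same rule once more.
    update = g + decay * new_trace if nesterov else new_trace
    return update, new_trace


def ema_step(g: float, ema: float, decay: float) -> float:
    """Undebiased EMA: ema = decay * ema + (1 - decay) * t."""
    return decay * ema + (1 - decay) * g


trace = ema = 0.0
for _ in range(200):
    _, trace = trace_step(1.0, trace, decay=0.9)
    ema = ema_step(1.0, ema, decay=0.9)
print(round(trace, 4), round(ema, 4))  # approaches 10.0 and 1.0
```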

class optax.TraceState(trace: base.Params)

Holds an aggregation of past updates.

trace: base.Params

Alias for field number 0

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

optax.update_infinity_moment(updates, moments, decay, eps)

Compute the exponential moving average of the infinity norm.

optax.update_moment(updates, moments, decay, order)

Compute the exponential moving average of the order-th moment.

optax.update_moment_per_elem_norm(updates, moments, decay, order)

Compute the EMA of the order-th moment of the element-wise norm.
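The three helpers differ in how each new element enters the running statistic. A scalar sketch under the assumption that the first two follow the usual EMA form and that the infinity-norm variant keeps a decayed running maximum as in Adamax; the element-wise norm is taken to be the modulus |g|, which matters for complex inputs. The _sketch names are hypothetical and this is not optax's code.

```python
def update_moment_sketch(g: float, moment: float, decay: float, order: int) -> float:
    # EMA of g**order, e.g. order=1 for a first moment, order=2 for a second.
    return (1 - decay) * g ** order + decay * moment


def update_moment_per_elem_norm_sketch(g: complex, moment: float, decay: float,
                                       order: int) -> float:
    # Same EMA, but of |g|**order, so complex inputs give a real moment.
    return (1 - decay) * abs(g) ** order + decay * moment


def update_infinity_moment_sketch(g: float, moment: float, decay: float,
                                  eps: float) -> float:
    # Decayed running maximum of |g|; eps keeps the result strictly positive.
    return max(abs(g) + eps, decay * moment)
```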

optax.with_extra_args_support(tx)

Wraps a gradient transformation so that it ignores extra args.

Return type:

GradientTransformationExtraArgs

optax.zero_nans()

A transformation which replaces NaNs with 0.

The state of the transformation has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update. This state is not used by the transformation internally, but lets users be aware when NaNs have been zeroed out.

Return type:

GradientTransformation

Returns:

A GradientTransformation.
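A plain-Python sketch of the behaviour described above, with pytrees modelled as dicts of lists. The name zero_nans_sketch is hypothetical, and the found_nan indicators are returned alongside the cleaned updates rather than threaded through an (init, update) pair.

```python
import math
from typing import Dict, List, Tuple

Tree = Dict[str, List[float]]


def zero_nans_sketch(updates: Tree) -> Tuple[Tree, Dict[str, bool]]:
    """Hypothetical helper: returns (cleaned_updates, found_nan)."""
    cleaned = {k: [0.0 if math.isnan(u) else u for u in leaf]
               for k, leaf in updates.items()}
    # One boolean per leaf, mirroring the documented found_nan state.
    found_nan = {k: any(math.isnan(u) for u in leaf) for k, leaf in updates.items()}
    return cleaned, found_nan


cleaned, found_nan = zero_nans_sketch({"w": [1.0, float("nan")], "b": [2.0]})
```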

class optax.ZeroNansState(found_nan: Any)

Contains a tree of NaN indicators.

The entry found_nan has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update.

found_nan: Any

Alias for field number 0

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.