Utilities

General

scale_gradient(inputs, scale)

Scales gradients for the backwards pass.

value_and_grad_from_state(value_fn)

Alternative to jax.value_and_grad that fetches the value and gradient from the state.

Scale gradient

optax.scale_gradient(inputs, scale)

Scales gradients for the backwards pass.

Parameters:
  • inputs (chex.ArrayTree) – A nested array.

  • scale (float) – The scale factor for the gradient on the backwards pass.

Return type:

chex.ArrayTree

Returns:

A nested array with the same structure as inputs, whose gradient is scaled by scale on the backward pass.
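To make the forward/backward asymmetry concrete, here is a minimal sketch of the behavior described above using the stop_gradient trick. The helper scale_gradient_sketch is hypothetical and is not optax's implementation (which may rely on a custom VJP instead). It only illustrates the semantics: the output value equals inputs, while gradients flowing back through it are multiplied by scale.

```python
import jax
import jax.numpy as jnp


def scale_gradient_sketch(inputs, scale):
  """Hypothetical stand-in: identity forward, gradient times `scale` backward."""

  def _scale_leaf(x):
    # Forward value: scale * x + (1 - scale) * x == x (up to rounding).
    # stop_gradient hides the second term from autodiff, so d(out)/dx = scale.
    return scale * x + jax.lax.stop_gradient((1.0 - scale) * x)

  # Apply leaf-wise so any nested array (pytree) is supported.
  return jax.tree_util.tree_map(_scale_leaf, inputs)


x = {"w": jnp.array([1.0, 2.0, 3.0])}
loss = lambda t: jnp.sum(scale_gradient_sketch(t, 0.5)["w"])
grads = jax.grad(loss)(x)  # every entry of grads["w"] equals 0.5
```

With scale=0.0 the sketch behaves like jax.lax.stop_gradient, and with scale=-1.0 it reverses gradients while leaving the forward value unchanged.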

Value and grad from state

optax.value_and_grad_from_state(value_fn)

Alternative to jax.value_and_grad that fetches the value and gradient from the state.

Line-search methods such as optax.scale_by_backtracking_linesearch() need to compute the objective value and gradient at the candidate iterate. This utility function lets the next iteration reuse that value and gradient instead of recomputing them, saving some computation.

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02

Parameters:

value_fn (Callable[…, Union[jax.Array, float]]) – function returning a scalar (a float or a 0-dimensional array) that can be differentiated with jax.value_and_grad().

Return type:

Callable[…, tuple[Union[float, jax.Array], optax.Updates]]

Returns:

A callable akin to jax.value_and_grad() that fetches the value and gradient from the state, if present. It raises an error if no value or gradient is found, if multiple values or gradients are found, or if the gradient found in the state is None. If the value found is infinite or NaN, the value and gradient are instead computed with jax.value_and_grad().

Numerical Stability

safe_int32_increment(count)

Increments an int32 counter by one.

safe_norm(x, min_norm[, ord, axis, keepdims])

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.

safe_root_mean_squares(x, min_rms)

Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct gradients.

Safe int32 increment

optax.safe_int32_increment(count)

Increments an int32 counter by one.

Normally max_int + 1 would overflow to min_int. This function ensures that once max_int is reached, the counter stays at max_int.

Parameters:

count (Union[Array, ndarray, bool_, number, float, int]) – a counter to be incremented.

Return type:

Union[Array, ndarray, bool_, number, float, int]

Returns:

A counter incremented by 1, or max_int if the maximum precision is reached.
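A minimal sketch of the saturating increment, assuming the counter is an int32 scalar. safe_int32_increment_sketch is a hypothetical name; the point is that jnp.where selects max_int instead of the wrapped-around value, and this also works under jax.jit.

```python
import jax.numpy as jnp


def safe_int32_increment_sketch(count):
  """Hypothetical stand-in: count + 1, saturating at the int32 maximum."""
  max_int32 = jnp.iinfo(jnp.int32).max
  count = jnp.asarray(count, dtype=jnp.int32)
  # At the maximum, count + 1 would wrap to min_int; jnp.where keeps max_int32.
  return jnp.where(count < max_int32, count + 1, max_int32)
```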

Safe norm

optax.safe_norm(x, min_norm, ord=None, axis=None, keepdims=False)

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.

The gradient of jnp.maximum(jnp.linalg.norm(x), min_norm) at x = 0 is NaN: JAX differentiates through both branches of jnp.maximum, and the gradient of the norm is undefined at 0. This function instead returns the correct gradient of 0 in that setting.

Parameters:
  • x (Union[Array, ndarray, bool_, number]) – jax array.

  • min_norm (Union[Array, ndarray, bool_, number, float, int]) – lower bound for the returned norm.

  • ord (Union[int, float, str, None]) – {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional. Order of the norm. inf means numpy’s inf object. The default is None.

  • axis (Union[None, tuple[int, ...], int]) – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.

  • keepdims (bool) – bool, optional. If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

The safe norm of the input, with correct gradients.
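A standard way to get a finite gradient here is the "double where" trick, sketched below for the default ord=None, axis=None case only. safe_norm_sketch is a hypothetical helper, not optax's code. It computes the norm once to decide whether clipping applies, then evaluates the differentiated norm on an input replaced by ones whenever clipping applies, so autodiff never sees the norm's singularity at 0.

```python
import jax
import jax.numpy as jnp


def safe_norm_sketch(x, min_norm):
  """Hypothetical stand-in for safe_norm, restricted to ord=None and axis=None."""
  norm = jnp.linalg.norm(x)
  small = norm <= min_norm  # boolean: no gradient flows through this comparison
  # When the norm is clipped, swap x for ones so the norm that gradients
  # flow through is never evaluated at 0.
  safe_x = jnp.where(small, jnp.ones_like(x), x)
  return jnp.where(small, min_norm, jnp.linalg.norm(safe_x))


zero_grad = jax.grad(safe_norm_sketch)(jnp.zeros(3), 0.0)  # all zeros, no NaN
```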

Safe root mean squares

optax.safe_root_mean_squares(x, min_rms)

Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct gradients.

The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at x = 0 is NaN: JAX differentiates through both branches of jnp.maximum, and the gradient of sqrt is undefined at 0. This function instead returns the correct gradient of 0 in that setting.

Parameters:
  • x (Union[Array, ndarray, bool_, number]) – jax array.

  • min_rms (Union[Array, ndarray, bool_, number, float, int]) – lower bound for the returned RMS.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

The safe RMS of the input, with correct gradients.
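For an array with n entries, sqrt(mean(abs_sq(x))) equals jnp.linalg.norm(x) / sqrt(n), so the same NaN issue and the same "double where" fix apply. The hypothetical safe_rms_sketch below illustrates that fix; it is not optax's implementation.

```python
import jax
import jax.numpy as jnp


def safe_rms_sketch(x, min_rms):
  """Hypothetical stand-in for safe_root_mean_squares."""
  abs_sq = lambda v: jnp.real(v * jnp.conj(v))  # |v|^2, also valid for complex v
  rms = jnp.sqrt(jnp.mean(abs_sq(x)))
  small = rms <= min_rms
  # Double-where trick: the sqrt that gradients flow through never sees 0.
  safe_x = jnp.where(small, jnp.ones_like(x), x)
  return jnp.where(small, min_rms, jnp.sqrt(jnp.mean(abs_sq(safe_x))))


v = jnp.array([3.0, 4.0])
rms = safe_rms_sketch(v, 0.0)  # sqrt((9 + 16) / 2)
```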

Linear Algebra Operators

matrix_inverse_pth_root(matrix, p[, ...])

Computes matrix^(-1/p), where p is a positive integer.

multi_normal(loc, log_scale)

Creates a multivariate normal distribution with diagonal covariance (MultiNormalDiagFromLogScale) from loc and log_scale.

power_iteration(matrix, *[, v0, num_iters, ...])

Power iteration algorithm.

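Multi normal

optax.multi_normal(loc, log_scale)

Creates a multivariate normal distribution with diagonal covariance, parameterized by its mean loc and the logarithm of its per-dimension scale, log_scale.

Return type:

MultiNormalDiagFromLogScale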
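The entry above gives only the return type. Assuming, as the class name suggests, that MultiNormalDiagFromLogScale is a multivariate normal with diagonal covariance whose per-dimension standard deviation is exp(log_scale), the hypothetical class below sketches the two operations such a distribution typically provides: reparameterized sampling and the log-density. It is not optax's class, and its method names are illustrative only.

```python
import math

import jax
import jax.numpy as jnp


class DiagNormalSketch:
  """Hypothetical diagonal Gaussian parameterized by loc and log_scale."""

  def __init__(self, loc, log_scale):
    self.loc = loc
    self.log_scale = log_scale
    self.scale = jnp.exp(log_scale)  # assumed: per-dimension standard deviation

  def sample(self, key, shape=()):
    # Reparameterized sample: loc + scale * eps with eps ~ N(0, I).
    eps = jax.random.normal(key, tuple(shape) + self.loc.shape, dtype=self.loc.dtype)
    return self.loc + self.scale * eps

  def log_prob(self, x):
    # Sum of independent univariate normal log-densities over the last axis.
    z = (x - self.loc) / self.scale
    per_dim = -0.5 * z**2 - self.log_scale - 0.5 * math.log(2.0 * math.pi)
    return jnp.sum(per_dim, axis=-1)


dist = DiagNormalSketch(jnp.array([0.0, 1.0]), jnp.array([0.0, math.log(2.0)]))
lp = dist.log_prob(jnp.array([0.5, -1.0]))
samples = dist.sample(jax.random.PRNGKey(0), (4,))  # shape (4, 2)
```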

Matrix inverse pth root

optax.matrix_inverse_pth_root(matrix, p, num_iters=100, ridge_epsilon=1e-06, error_tolerance=1e-06, precision=Precision.HIGHEST)

Computes matrix^(-1/p), where p is a positive integer.

This function uses the coupled Newton iteration algorithm to compute the inverse pth root of a matrix.

References

Nicholas J. Higham, Functions of Matrices: Theory and Computation, p. 184, Eq. 7.18. https://epubs.siam.org/doi/book/10.1137/1.9780898717778

Parameters:
  • matrix (Union[Array, ndarray, bool_, number]) – the symmetric PSD matrix whose power is to be computed.

  • p (int) – exponent, for p a positive integer.

  • num_iters (int) – Maximum number of iterations.

  • ridge_epsilon (float) – Ridge epsilon added to make the matrix positive definite.

  • error_tolerance (float) – Error threshold used to terminate the iterations early.

  • precision (Precision) – XLA precision flag; the available options are: a) lax.Precision.DEFAULT (better step time, but less precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).

Returns:

matrix^(-1/p)
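The coupled Newton iteration can be sketched as follows, assuming a symmetric positive definite input and a fixed iteration count, with no ridge term, error tracking, or early stopping (so ridge_epsilon and error_tolerance above are not modelled). Starting from M_0 = zA and H_0 = z^(1/p) I, each step forms T_k = ((p + 1) I - M_k) / p and updates H_{k+1} = H_k T_k and M_{k+1} = T_k^p M_k. All iterates are polynomials in A and therefore commute, so the invariant M_k = H_k^p A holds throughout: as M_k approaches I, H_k approaches A^(-1/p). The scaling z = (1 + p) / (2 ||A||_F) is our assumption; it keeps the eigenvalues of M_0 inside (0, (p + 1) / 2], where the iteration converges. inverse_pth_root_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def inverse_pth_root_sketch(matrix, p, num_iters=50):
  """Hypothetical coupled Newton iteration for matrix^(-1/p), SPD input assumed."""
  identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
  # Scale so the eigenvalues of the initial m lie in (0, (p + 1) / 2].
  z = (1.0 + p) / (2.0 * jnp.linalg.norm(matrix))  # Frobenius norm
  m = z * matrix                 # invariant: m == h^p @ matrix; m -> identity
  h = z ** (1.0 / p) * identity  # h -> matrix^(-1/p)

  def body(_, carry):
    m, h = carry
    t = ((p + 1) * identity - m) / p
    return jnp.linalg.matrix_power(t, p) @ m, h @ t

  _, h = jax.lax.fori_loop(0, num_iters, body, (m, h))
  return h


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
h = inverse_pth_root_sketch(a, 2)  # h @ h @ a should be close to the identity
```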

Power iteration

optax.power_iteration(matrix, *, v0=None, num_iters=100, error_tolerance=1e-06, precision=Precision.HIGHEST, key=None)

Power iteration algorithm.

This algorithm computes the dominant eigenvalue and its associated eigenvector of a diagonalizable matrix. This matrix can be given as an array or as a callable implementing a matrix-vector product.

References

Wikipedia contributors. Power iteration.

Parameters:
  • matrix (Union[chex.Array, Callable[[chex.ArrayTree], chex.ArrayTree]]) – a square matrix, either as an array or a callable implementing a matrix-vector product.

  • v0 (Optional[chex.ArrayTree]) – initial vector approximating the dominant eigenvector. If matrix is an array of size (n, n), v0 must be a vector of size (n,). If instead matrix is a callable, then v0 must be a tree with the same structure as the input of this callable. If this argument is None and matrix is an array, then a random vector sampled from a uniform distribution in [-1, 1] is used as initial vector.

  • num_iters (int) – Number of power iterations.

  • error_tolerance (float) – Iterative exit condition. The procedure stops when the relative error of the estimate of the dominant eigenvalue is below this threshold.

  • precision (lax.Precision) – XLA precision flag; the available options are: a) lax.Precision.DEFAULT (better step time, but less precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).

  • key (Optional[chex.PRNGKey]) – random key for the initialization of v0 when not given explicitly. When this argument is None, jax.random.PRNGKey(0) is used.

Return type:

tuple[chex.Numeric, chex.ArrayTree]

Returns:

A pair (eigenvalue, eigenvector), where eigenvalue is the dominant eigenvalue of matrix and eigenvector is its associated eigenvector.

Changed in version 0.2.2: matrix can be a callable. Reversed the order of the return parameters, from (eigenvector, eigenvalue) to (eigenvalue, eigenvector).
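A minimal dense-matrix sketch of power iteration, assuming matrix is an (n, n) array with a unique eigenvalue of largest magnitude: repeatedly multiply by the matrix and renormalize, then read off the eigenvalue as the Rayleigh quotient. power_iteration_sketch is hypothetical; unlike the function documented above, it runs a fixed number of iterations (no error_tolerance check) and does not accept a callable matrix-vector product. Like the documented function, it returns (eigenvalue, eigenvector).

```python
import jax
import jax.numpy as jnp


def power_iteration_sketch(matrix, num_iters=100, key=None):
  """Hypothetical dense-matrix power iteration returning (eigenvalue, eigenvector)."""
  if key is None:
    key = jax.random.PRNGKey(0)
  # Random start in [-1, 1], as described above for v0=None.
  v = jax.random.uniform(key, (matrix.shape[0],), minval=-1.0, maxval=1.0)
  v = v / jnp.linalg.norm(v)

  def body(_, v):
    w = matrix @ v
    return w / jnp.linalg.norm(w)  # renormalize to avoid overflow/underflow

  v = jax.lax.fori_loop(0, num_iters, body, v)
  eigenvalue = v @ (matrix @ v)  # Rayleigh quotient (v has unit norm)
  return eigenvalue, v


a = jnp.array([[2.0, 1.0], [1.0, 2.0]])  # eigenvalues 3 and 1
lam, vec = power_iteration_sketch(a)
```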

Second Order Optimization

fisher_diag(negative_log_likelihood, params, ...)

Computes the diagonal of the (observed) Fisher information matrix.

hessian_diag(loss, params, inputs, targets)

Computes the diagonal of the Hessian of loss at (inputs, targets).

hvp(loss, v, params, inputs, targets)

Performs an efficient vector-Hessian (of loss) product.

Fisher diagonal

optax.second_order.fisher_diag(negative_log_likelihood, params, inputs, targets)

Computes the diagonal of the (observed) Fisher information matrix.

Parameters:
  • negative_log_likelihood (LossFn) – the negative log likelihood function with expected signature loss = fn(params, inputs, targets).

  • params (Any) – model parameters.

  • inputs (Array) – inputs at which negative_log_likelihood is evaluated.

  • targets (Array) – targets at which negative_log_likelihood is evaluated.

Return type:

Array

Returns:

An Array containing the diagonal of the (observed) Fisher information matrix of negative_log_likelihood, evaluated at (params, inputs, targets).
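One common way to compute the diagonal of the observed (empirical) Fisher matrix is to square the entries of the gradient of the negative log-likelihood, since the diagonal of g g^T is g * g. The hypothetical fisher_diag_sketch below follows that approach, which we assume matches this function's behavior; params may be any pytree and are flattened with jax.flatten_util.ravel_pytree.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def fisher_diag_sketch(negative_log_likelihood, params, inputs, targets):
  """Hypothetical stand-in: squared entries of the flattened NLL gradient."""
  grads = jax.grad(negative_log_likelihood)(params, inputs, targets)
  flat_grads, _ = ravel_pytree(grads)  # one vector for any params pytree
  return jnp.square(flat_grads)


def nll(params, inputs, targets):
  # Gaussian NLL up to an additive constant: 0.5 * ||inputs @ w - targets||^2.
  residual = inputs @ params["w"] - targets
  return 0.5 * jnp.sum(residual**2)


params = {"w": jnp.array([1.0, 2.0])}
diag = fisher_diag_sketch(nll, params, jnp.eye(2), jnp.zeros(2))  # gradient is [1, 2]
```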

Hessian diagonal

optax.second_order.hessian_diag(loss, params, inputs, targets)

Computes the diagonal of the Hessian of loss at (inputs, targets).

Parameters:
  • loss (LossFn) – the loss function.

  • params (Any) – model parameters.

  • inputs (Array) – inputs at which loss is evaluated.

  • targets (Array) – targets at which loss is evaluated.

Return type:

Array

Returns:

An Array containing the diagonal of the Hessian of loss evaluated at (params, inputs, targets).
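The diagonal can be obtained by taking one Hessian-vector product per standard basis vector e_i and keeping entry i of H e_i. The hypothetical hessian_diag_sketch below does this with forward-over-reverse differentiation. It is an illustration, not optax's implementation, and since it performs one product per parameter it is only practical for small models.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def hessian_diag_sketch(loss, params, inputs, targets):
  """Hypothetical stand-in: Hessian diagonal from one HVP per parameter."""
  flat, unravel = ravel_pytree(params)
  f = lambda p: loss(unravel(p), inputs, targets)

  def hvp(v):
    # Forward-over-reverse: directional derivative of the gradient along v.
    return jax.jvp(jax.grad(f), (flat,), (v,))[1]

  basis = jnp.eye(flat.size, dtype=flat.dtype)
  # Row i of the vmapped result is H @ e_i, whose i-th entry is H[i, i].
  return jnp.diagonal(jax.vmap(hvp)(basis))


def loss(params, inputs, targets):
  w = params["w"]
  # With inputs = [3, 4] the Hessian is [[6, 1], [1, 8]]; targets is unused here.
  return jnp.sum(inputs * w**2) + w[0] * w[1]


diag = hessian_diag_sketch(loss, {"w": jnp.array([1.0, 2.0])}, jnp.array([3.0, 4.0]), None)
```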

Hessian vector product

optax.second_order.hvp(loss, v, params, inputs, targets)

Performs an efficient vector-Hessian (of loss) product.

Parameters:
  • loss (LossFn) – the loss function.

  • v (Array) – a vector with as many entries as the flattened (raveled) params.

  • params (Any) – model parameters.

  • inputs (Array) – inputs at which loss is evaluated.

  • targets (Array) – targets at which loss is evaluated.

Return type:

Array

Returns:

An Array corresponding to the product of v and the Hessian of loss evaluated at (params, inputs, targets).
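A Hessian-vector product can be computed without materializing the Hessian by differentiating the gradient in the direction v (forward-over-reverse). For a twice continuously differentiable loss the Hessian is symmetric, so the vector-Hessian and Hessian-vector products coincide. hvp_sketch below is a hypothetical illustration that takes v as a flat vector, matching the v parameter above.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def hvp_sketch(loss, v, params, inputs, targets):
  """Hypothetical stand-in: Hessian-vector product without forming the Hessian."""
  flat, unravel = ravel_pytree(params)
  f = lambda p: loss(unravel(p), inputs, targets)
  # The jvp of the gradient along v equals H @ v.
  return jax.jvp(jax.grad(f), (flat,), (v,))[1]


def loss(params, inputs, targets):
  # Least squares: the Hessian with respect to params is inputs.T @ inputs.
  return 0.5 * jnp.sum((inputs @ params - targets) ** 2)


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
v = jnp.array([1.0, -1.0])
out = hvp_sketch(loss, v, jnp.zeros(2), x, jnp.zeros(2))  # [[10, 14], [14, 20]] @ v
```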

Tree

NamedTupleKey(tuple_name, name)

KeyType for a NamedTuple in a tree.

tree_add(tree_x, tree_y, *other_trees)

Add two (or more) pytrees.

tree_add_scalar_mul(tree_x, scalar, tree_y)

Add two trees, where the second tree is scaled by a scalar.

tree_div(tree_x, tree_y)

Divide two pytrees.

tree_get(tree, key[, default, filtering])

Extract a value from a pytree matching a given key.

tree_get_all_with_path(tree, key[, filtering])

Extract values of a pytree matching a given key.

tree_l1_norm(tree)

Compute the l1 norm of a pytree.

tree_l2_norm(tree[, squared])

Compute the l2 norm of a pytree.

tree_map_params(initable, f, state, /, *rest)

Apply a callable over all params in the given optimizer state.

tree_mul(tree_x, tree_y)

Multiply two pytrees.

tree_ones_like(tree[, dtype])

Creates an all-ones tree with the same structure.

tree_random_like(rng_key, target_tree[, sampler])

Create a tree with random entries of the same shape as the target tree.

tree_scalar_mul(scalar, tree)

Multiply a tree by a scalar.

tree_set(tree[, filtering])

Creates a copy of tree with some values replaced as specified by kwargs.

tree_sub(tree_x, tree_y)

Subtract two pytrees.

tree_sum(tree)

Compute the sum of all the elements in a pytree.

tree_vdot(tree_x, tree_y)

Compute the inner product between two pytrees.

tree_zeros_like(tree[, dtype])

Creates an all-zeros tree with the same structure.

NamedTupleKey

class optax.tree_utils.NamedTupleKey(tuple_name, name)

KeyType for a NamedTuple in a tree.

When a filtering function filtering(path: KeyPath, value: Any) -> bool: ... is passed to optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set(), it can check whether an entry of the path is a NamedTupleKey and, if so, whether the name of the named tuple is the one being searched for.

See also

jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), optax.tree_utils.tree_set()

tuple_name

name of the tuple containing the key.

Type:

str

name

name of the key.

Type:

str

Added in version 0.2.2.

Tree add

optax.tree_utils.tree_add(tree_x, tree_y, *other_trees)

Add two (or more) pytrees.

Parameters:
  • tree_x (Any) – first pytree.

  • tree_y (Any) – second pytree.

  • *other_trees (Any) – optional other trees to add

Return type:

Any

Returns:

the sum of the two (or more) pytrees.

Changed in version 0.2.1: Added optional *other_trees argument.
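Leaf-wise addition of two or more pytrees can be sketched with jax.tree_util.tree_map, which accepts several trees with identical structure. tree_add_sketch is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp


def tree_add_sketch(tree_x, tree_y, *other_trees):
  """Hypothetical stand-in: leaf-wise sum of two or more pytrees."""
  # tree_map walks all trees in lockstep; they must share the same structure.
  return jax.tree_util.tree_map(
      lambda x, *rest: sum(rest, x), tree_x, tree_y, *other_trees
  )


a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(1.0)}
total = tree_add_sketch(a, a, a)  # leaves: w -> [3.0, 6.0], b -> 3.0
```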

Tree add and scalar multiply

optax.tree_utils.tree_add_scalar_mul(tree_x, scalar, tree_y)

Add two trees, where the second tree is scaled by a scalar.

In infix notation, the function performs out = tree_x + scalar * tree_y.

Parameters:
  • tree_x (Any) – first pytree.

  • scalar (Union[float, Array]) – scalar value.

  • tree_y (Any) – second pytree.

Return type:

Any

Returns:

a pytree with the same structure as tree_x and tree_y.

Tree divide

optax.tree_utils.tree_div(tree_x, tree_y)

Divide two pytrees.

Parameters:
  • tree_x (Any) – first pytree.

  • tree_y (Any) – second pytree.

Return type:

Any

Returns:

the quotient of the two pytrees.

Fetch a single value that matches a given key

optax.tree_utils.tree_get(tree, key, default=None, filtering=None)

Extract a value from a pytree matching a given key.

Search in the tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).

If the tree does not contain key, returns default.

Raises a KeyError if multiple values of key are found in tree.

More generally, you may first get all pairs (path_to_value, value) for a given key using optax.tree_utils.tree_get_all_with_path(). You may then define a filtering operation filtering(path: KeyPath, value: Any) -> bool: ... that selects the specific values you want to fetch by looking at the type of the value or at the path to that value. Note that, unlike the paths returned by jax.tree_util.tree_leaves_with_path(), the paths analyzed by the filtering operation in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set() include the names of the named tuples along the path. Concretely, if the value considered is in the attribute key of a named tuple called MyNamedTuple, the last element of the path will be an optax.tree_utils.NamedTupleKey containing both name=key and tuple_name='MyNamedTuple'. That way you can distinguish between values stored under the same key in different named tuples (which arise, for example, when chaining transformations in optax). See the last example below.

Examples

Basic usage

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> count = optax.tree_utils.tree_get(state, 'count')
>>> print(count)
0

Usage with a filtering operation

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> lr = optax.tree_utils.tree_get(
...   state, 'learning_rate', filtering=filtering
... )
>>> print(lr)
1.0

Extracting a named tuple by its name

>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32))

Differentiating between two values by the name of their named tuples.

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...   optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> filtering = (
...      lambda p, v: isinstance(p[-1], optax.tree_utils.NamedTupleKey)
...      and p[-1].tuple_name == 'ScaleByAdamState'
... )
>>> count = optax.tree_utils.tree_get(state, 'count', filtering=filtering)
>>> print(count)
0

Parameters:
  • tree (Any) – tree to search in.

  • key (Any) – keyword or field to search in tree for.

  • default (Optional[Any]) – default value to return if key is not found in tree.

  • filtering (Optional[Callable[[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any], bool]]) – optional callable to further filter values in tree that match the key. filtering(path: KeyPath, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.

Return type:

Any

Returns:

value

The value in tree matching the given key. If none is found, default is returned. If multiple are found, an error is raised.

Raises:

KeyError – If multiple values of key are found in tree.

Added in version 0.2.2.

Fetch all values that match a given key

optax.tree_utils.tree_get_all_with_path(tree, key, filtering=None)

Extract values of a pytree matching a given key.

Search in a pytree tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).

The key/field key may appear more than once in tree, so this function returns a list of all values corresponding to key, each paired with the path to that value. The path is a sequence of KeyEntry objects that can be converted to a readable string with jax.tree_util.keystr(); see the examples below.

Examples

Basic usage

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate'
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))

Usage with a filtering operation

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> filtering = lambda path, value: isinstance(value, tuple)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate', filtering
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))

Parameters:
  • tree (Any) – tree to search in.

  • key (Any) – keyword or field to search in tree for.

  • filtering (Optional[Callable[[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any], bool]]) – optional callable to further filter values in tree that match the key. filtering(path: KeyPath, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.

Return type:

list[tuple[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any]]

Returns:

values_with_path

A list of tuples of the form (path_to_value, value). Here value is an entry of the tree that corresponds to key, and path_to_value is a tuple of KeyEntry objects, each of which is a jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, or optax.tree_utils.NamedTupleKey.

Added in version 0.2.2.

Tree l1 norm

optax.tree_utils.tree_l1_norm(tree)

Compute the l1 norm of a pytree.

Parameters:

tree (Any) – pytree.

Return type:

Union[Array, ndarray, bool_, number, float, int]

Returns:

a scalar value.

Tree l2 norm

optax.tree_utils.tree_l2_norm(tree, squared=False)

Compute the l2 norm of a pytree.

Parameters:
  • tree (Any) – pytree.

  • squared (bool) – whether the norm should be returned squared or not.

Return type:

Union[Array, ndarray, bool_, number, float, int]

Returns:

a scalar value.
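Note that the norm is global: it is the l2 norm of all leaves concatenated into a single vector, not a per-leaf norm. A hypothetical sketch (tree_l2_norm_sketch, not optax's code) that also shows the squared option:

```python
import jax
import jax.numpy as jnp


def tree_l2_norm_sketch(tree, squared=False):
  """Hypothetical stand-in: one global l2 norm over every leaf of a pytree."""
  leaves = jax.tree_util.tree_leaves(tree)
  # |x|^2 summed over all entries of all leaves; real() also covers complex leaves.
  sq_norm = sum(jnp.sum(jnp.real(x * jnp.conj(x))) for x in leaves)
  return sq_norm if squared else jnp.sqrt(sq_norm)


tree = {"a": jnp.array([3.0]), "b": jnp.array([[4.0]])}
norm = tree_l2_norm_sketch(tree)  # sqrt(3^2 + 4^2)
```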

Tree map parameters

optax.tree_utils.tree_map_params(initable, f, state, /, *rest, transform_non_params=None, is_leaf=None)

Apply a callable over all params in the given optimizer state.

This function exists to help construct partition specs over optimizer states, in the case that a partition spec is already known for the parameters.

For example, the following replaces every parameter tree in the optimizer state with a copy of the given partition spec. The argument transform_non_params can be used to replace any remaining fields as required; here, those fields are replaced by None.

>>> import jax.numpy as jnp
>>> import optax
>>> params, specs = jnp.array(0.), jnp.array(0.)  # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>> opt_specs = optax.tree_map_params(
...     opt,
...     lambda _, spec: spec,
...     state,
...     specs,
...     transform_non_params=lambda _: None,
...     )

Parameters:
  • initable (Union[Callable[[optax.Params], optax.OptState], Initable]) – A callable taking parameters and returning an optimizer state, or an object with an init attribute having the same function.

  • f (Callable[…, Any]) – A callable that will be applied for all copies of the parameter tree within this optimizer state.

  • state (optax.OptState) – The optimizer state to map over.

  • *rest (Any) – Additional arguments, having the same shape as the parameter tree, that will be passed to f.

  • transform_non_params (Optional[Callable[…, Any]]) – An optional function that will be called on all non-parameter fields within the optimizer state.

  • is_leaf (Optional[Callable[[optax.Params], bool]]) – Passed through to jax.tree_map. This makes it possible to ignore parts of the parameter tree e.g. when the gradient transformations modify the shape of the original pytree, such as for optax.masked.

Return type:

optax.OptState

Returns:

The result of applying the function f on all trees in the optimizer’s state that have the same shape as the parameter tree, along with the given optional extra arguments.

Tree multiply

optax.tree_utils.tree_mul(tree_x, tree_y)

Multiply two pytrees.

Parameters:
  • tree_x (Any) – first pytree.

  • tree_y (Any) – second pytree.

Return type:

Any

Returns:

the product of the two pytrees.

Tree ones like

optax.tree_utils.tree_ones_like(tree, dtype=None)

Creates an all-ones tree with the same structure.

Parameters:
  • tree (Any) – pytree.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype to use for the tree of ones.

Return type:

Any

Returns:

an all-ones tree with the same structure as tree.

Tree with random values

optax.tree_utils.tree_random_like(rng_key, target_tree, sampler=<function normal>)

Create a tree with random entries of the same shape as the target tree.

Parameters:
  • rng_key (chex.PRNGKey) – the key for the random number generator.

  • target_tree (chex.ArrayTree) – the tree whose structure to match. Leaves must be arrays.

  • sampler (Callable[[chex.PRNGKey, base.Shape], chex.Array]) – the noise sampling function, by default jax.random.normal.

Return type:

chex.ArrayTree

Returns:

a random tree with the same structure as target_tree, whose leaves are drawn using sampler.

Added in version 0.2.1.
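The sketch below shows one way to fill a tree with independent random values: split the key once per leaf and call sampler(key, shape) for each leaf. tree_random_like_sketch is hypothetical; we assume, as the sampler type above indicates, that the sampler takes a key and a shape.

```python
import jax
import jax.numpy as jnp


def tree_random_like_sketch(rng_key, target_tree, sampler=jax.random.normal):
  """Hypothetical stand-in: independent samples shaped like each leaf."""
  leaves, treedef = jax.tree_util.tree_flatten(target_tree)
  # One fresh key per leaf, so leaves with equal shapes get different samples.
  keys = jax.random.split(rng_key, len(leaves))
  samples = [sampler(k, leaf.shape) for k, leaf in zip(keys, leaves)]
  return jax.tree_util.tree_unflatten(treedef, samples)


target = {"a": jnp.zeros((2, 3)), "b": jnp.zeros((2, 3))}
noise = tree_random_like_sketch(jax.random.PRNGKey(0), target)
```

A uniform sampler can be passed as, for example, lambda k, s: jax.random.uniform(k, s, minval=-1.0, maxval=1.0).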

Tree scalar multiply

optax.tree_utils.tree_scalar_mul(scalar, tree)

Multiply a tree by a scalar.

In infix notation, the function performs out = scalar * tree.

Parameters:
  • scalar (Union[float, Array]) – scalar value.

  • tree (Any) – pytree.

Return type:

Any

Returns:

a pytree with the same structure as tree.

Set values in a tree

optax.tree_utils.tree_set(tree, filtering=None, /, **kwargs)

Creates a copy of tree with some values replaced as specified by kwargs.

Search in the tree for keys in **kwargs (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple). If such a key is found, replace the corresponding value with the one given in **kwargs.

Raises a KeyError if some keys in **kwargs are not present in the tree.

Note

The recommended way to inject hyperparameter schedules is through optax.inject_hyperparams(). This function is a helper for other purposes.

Examples

Basic usage

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())

Usage with a filtering operation

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
...  )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...   state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))

Parameters:
  • tree (Any) – pytree whose values are to be replaced.

  • filtering (Optional[Callable[[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any], bool]]) – optional callable to further filter values in tree that match the keys to replace. filtering(path: KeyPath, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches a given key.

  • **kwargs (Any) – dictionary of keys with values to replace in tree.

Return type:

Any

Returns:

new_tree

new pytree with the same structure as tree. For each element in tree whose key/field matches a key in **kwargs, its value is set by the corresponding value in **kwargs.

Raises:

KeyError – If no values of some key in **kwargs are found in tree or none of the values satisfy the filtering operation.

Added in version 0.2.2.

Tree subtract

optax.tree_utils.tree_sub(tree_x, tree_y)

Subtract two pytrees.

Parameters:
  • tree_x (Any) – first pytree.

  • tree_y (Any) – second pytree.

Return type:

Any

Returns:

the difference of the two pytrees.

Tree sum

optax.tree_utils.tree_sum(tree)

Compute the sum of all the elements in a pytree.

Parameters:

tree (Any) – pytree.

Return type:

Union[Array, ndarray, bool_, number, float, int]

Returns:

a scalar value.

Tree inner product

optax.tree_utils.tree_vdot(tree_x, tree_y)

Compute the inner product between two pytrees.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> optax.tree_utils.tree_vdot(
...   {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...   {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)

Parameters:
  • tree_x (Any) – first pytree to use.

  • tree_y (Any) – second pytree to use.

Return type:

Union[Array, ndarray, bool_, number, float, int]

Returns:

inner product between tree_x and tree_y, a scalar value.

Implementation detail: we upcast the values to the highest precision to avoid numerical issues.
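A minimal sketch of the pytree inner product: take jnp.vdot leaf by leaf and add the results. tree_vdot_sketch is hypothetical, and it assumes the implementation detail above refers to requesting jax.lax.Precision.HIGHEST for each dot product. Note that jnp.vdot flattens its inputs and conjugates its first argument when it is complex.

```python
import functools
import operator

import jax
import jax.numpy as jnp


def tree_vdot_sketch(tree_x, tree_y):
  """Hypothetical stand-in: sum of leaf-wise vdots between two pytrees."""
  # Assumption: "highest precision" means asking XLA for Precision.HIGHEST.
  vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)
  leaf_dots = jax.tree_util.tree_map(vdot, tree_x, tree_y)
  return jax.tree_util.tree_reduce(operator.add, leaf_dots)


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([1.0, 2.0])}
y = {"a": jnp.array([-1.0, -1.0]), "b": jnp.array([1.0, 1.0])}
inner = tree_vdot_sketch(x, y)  # (-1 - 2) + (1 + 2)
```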

Tree zeros like

optax.tree_utils.tree_zeros_like(tree, dtype=None)

Creates an all-zeros tree with the same structure.

Parameters:
  • tree (Any) – pytree.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype to use for the tree of zeros.

Return type:

Any

Returns:

an all-zeros tree with the same structure as tree.