Utilities#
General#
scale_gradient | Scales gradients for the backwards pass.
value_and_grad_from_state | Alternative to jax.value_and_grad() that fetches value, grad from state.
Scale gradient#
- optax.scale_gradient(inputs: optax.ArrayTree, scale: jax.typing.ArrayLike) -> optax.ArrayTree
Scales gradients for the backwards pass.
- Parameters:
inputs – A nested array.
scale – The scale factor for the gradient on the backwards pass.
- Returns:
An array of the same structure as inputs, with scaled backward gradient.
Value and grad from state#
- optax.value_and_grad_from_state(value_fn: Callable[..., jax.typing.ArrayLike]) -> Callable[..., tuple[jax.typing.ArrayLike, optax.Updates]]
Alternative to jax.value_and_grad() that fetches value, grad from state.
Line-search methods such as optax.scale_by_backtracking_linesearch() need to compute the gradient and objective function at the candidate iterate. This utility function lets the next iteration reuse that objective value and gradient, which saves computation.
- Parameters:
value_fn – function returning a scalar (float or array of dimension 1), amenable to differentiation in jax using jax.value_and_grad().
- Returns:
A callable akin to jax.value_and_grad() that fetches value and grad from the state if present. This function raises an error if no value or grad is found, or if multiple values and grads are found. If a value is found but is infinite or NaN, the value and grad are computed using jax.value_and_grad(). If the gradient found in the state is None, raises an error.
Examples
>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     value, grad = value_and_grad(params, state=opt_state)
...     updates, opt_state = solver.update(
...         grad, opt_state, params, value=value, grad=grad, value_fn=fn
...     )
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02
Numerical Stability#
safe_increment | Increments counter by one while avoiding overflow.
safe_norm | Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
safe_root_mean_squares | Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct grads.
Safe increment#
- optax.safe_increment(count: jax.typing.ArrayLike) -> jax.typing.ArrayLike
Increments counter by one while avoiding overflow.
Let max_val and min_val be the maximum and minimum possible values for the dtype of count. Normally max_val + 1 would overflow to min_val. This function ensures that when max_val is reached the counter stays at max_val.
- Parameters:
count – a counter to be incremented.
- Returns:
A counter incremented by 1, or max_val if the maximum value is reached.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
>>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32))
Array(2147483647, dtype=int32)
Added in version 0.2.4.
Safe norm#
- optax.safe_norm(x: jax.typing.ArrayLike, min_norm: jax.typing.ArrayLike, ord: int | float | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) -> Array
Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
The gradient of jnp.maximum(jnp.linalg.norm(x), min_norm) at 0.0 is NaN, because JAX evaluates both branches of the jnp.maximum. This function instead returns the correct gradient of 0.0 in that case.
- Parameters:
x – jax array.
min_norm – lower bound for the returned norm.
ord – {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional. Order of the norm. inf means numpy’s inf object. The default is None.
axis – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.
keepdims – bool, optional. If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.
- Returns:
The safe norm of the input vector, accounting for correct gradient.
Safe root mean squares#
- optax.safe_root_mean_squares(x: jax.typing.ArrayLike, min_rms: jax.typing.ArrayLike) -> Array
Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct grads.
The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at 0.0 is NaN, because JAX evaluates both branches of the jnp.maximum. This function instead returns the correct gradient of 0.0 in that case.
- Parameters:
x – jax array.
min_rms – lower bound for the returned root mean square.
- Returns:
The safe RMS of the input vector, accounting for correct gradient.
Linear Algebra Operators#
matrix_inverse_pth_root | Computes matrix^(-1/p), where p is a positive integer.
power_iteration | Power iteration algorithm.
nnls | Solves the non-negative least squares problem.
Matrix inverse pth root#
- optax.matrix_inverse_pth_root(matrix: jax.typing.ArrayLike, p: jax.typing.ArrayLike, num_iters: jax.typing.ArrayLike = 100, ridge_epsilon: jax.typing.ArrayLike = 1e-06, error_tolerance: jax.typing.ArrayLike = 1e-06, precision: Precision = Precision.HIGHEST)
Computes matrix^(-1/p), where p is a positive integer.
This function uses the coupled Newton iterations algorithm to compute a matrix's inverse pth root.
- Parameters:
matrix – the symmetric PSD matrix whose power is to be computed.
p – exponent, for p a positive integer.
num_iters – Maximum number of iterations.
ridge_epsilon – Ridge epsilon added to make the matrix positive definite.
error_tolerance – Error indicator, useful for early termination.
precision – precision XLA related flag, the available options are: a) lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).
- Returns:
matrix^(-1/p)
References
- Nicholas J. Higham, Functions of Matrices: Theory and Computation, p. 184, Eq. 7.18. https://epubs.siam.org/doi/book/10.1137/1.9780898717778
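To make the definition of matrix^(-1/p) concrete, here is a small eigendecomposition-based reference in JAX. The helper inverse_pth_root_eigh is hypothetical. It is not the coupled Newton iteration used by optax.matrix_inverse_pth_root, and it omits the ridge_epsilon regularization. It can still serve as a sanity check on small symmetric positive definite matrices.

```python
import jax.numpy as jnp


def inverse_pth_root_eigh(matrix, p):
    """Reference matrix^(-1/p) for a symmetric positive definite matrix.

    Hypothetical helper for illustration only: it uses an eigendecomposition,
    not the coupled Newton iterations described above.
    """
    eigvals, eigvecs = jnp.linalg.eigh(matrix)
    # V diag(lambda^(-1/p)) V^T: scale each eigenvector column, then recombine.
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
root = inverse_pth_root_eigh(a, p=2)
# root approximates a^(-1/2), so root @ root @ a should be close to the identity.
check = root @ root @ a
```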
Power iteration#
- optax.power_iteration(matrix: jax.typing.ArrayLike | Callable[[base.ArrayTree], base.ArrayTree], *, v0: base.ArrayTree | None = None, num_iters: jax.typing.ArrayLike = 100, error_tolerance: jax.typing.ArrayLike = 1e-06, precision: lax.Precision = Precision.HIGHEST, key: base.PRNGKey | None = None) -> tuple[jax.typing.ArrayLike, base.ArrayTree]
Power iteration algorithm.
This algorithm computes the dominant eigenvalue and its associated eigenvector of a diagonalizable matrix. The dominant eigenvalue is the eigenvalue of largest magnitude, and its absolute value is the spectral radius. This matrix can be given as an array or as a callable implementing a matrix-vector product.
- Parameters:
matrix – a square matrix, either as an array or a callable implementing a matrix-vector product.
v0 – initial vector approximating the dominant eigenvector. If matrix is an array of size (n, n), v0 must be a vector of size (n,). If instead matrix is a callable, then v0 must be a tree with the same structure as the input of this callable. If this argument is None and matrix is an array, then a random vector sampled from a uniform distribution in [-1, 1] is used as initial vector.
num_iters – Number of power iterations.
error_tolerance – Iterative exit condition. The procedure stops when the relative error of the estimate of the dominant eigenvalue is below this threshold.
precision – precision XLA related flag, the available options are: a) lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).
key – random key for the initialization of v0 when not given explicitly. When this argument is None, jax.random.PRNGKey(0) is used.
- Returns:
A pair (eigenvalue, eigenvector), where eigenvalue is the dominant eigenvalue of matrix and eigenvector is its associated eigenvector.
References
Wikipedia contributors. Power iteration.
Note
If the matrix is not diagonalizable or the dominant eigenvalue is not unique, the algorithm may not converge.
Changed in version 0.2.2: matrix can be a callable. Reversed the order of the return parameters, from (eigenvector, eigenvalue) to (eigenvalue, eigenvector).
Non-negative least squares#
- optax.nnls(A: Array, b: Array, iters: int, unroll: int | bool = 1, L: jax.typing.ArrayLike | None = None) -> Array
Solves the non-negative least squares problem.
Minimizes \(\|A x - b\|_2\) subject to \(x \geq 0\).
Uses the fast projected gradient (FPG) algorithm of Polyak 2015.
- Parameters:
A – Input matrix of shape (M, N).
b – Input vector of shape (M,) or matrix of shape (M, K).
iters – Number of iterations to run the algorithm for.
unroll – Unroll parameter passed to lax.scan.
L – An upper bound on the spectral radius of A.mT @ A (optional).
- Returns:
A solution vector of shape (N,) or matrix of shape (N, K).
Examples
>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> b = jnp.array([5., 6.])
>>> x = optax.nnls(A, b, 10**3)
>>> print(f"{x[0]:.2f}")
0.00
>>> print(f"{x[1]:.2f}")
1.70
References
Roman A. Polyak, Projected gradient method for non-negative least square, 2015
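To show the idea behind projected gradient methods for this problem, here is a plain (non-accelerated) projected gradient loop in JAX. Each step takes a gradient step on 0.5 * ||A x - b||^2 and then projects onto x >= 0. The helper projected_gradient_nnls is hypothetical. It does not reproduce the fast projected gradient (FPG) variant that optax.nnls uses. The step size 1/L uses the largest eigenvalue of A^T A, matching the role of the L parameter above.

```python
import jax
import jax.numpy as jnp


def projected_gradient_nnls(A, b, iters=5000):
    """Plain projected gradient for min ||A x - b||^2 subject to x >= 0.

    Hypothetical helper for illustration; not Polyak's FPG algorithm.
    """
    AtA = A.T @ A
    Atb = A.T @ b
    # Step size 1/L, where L is the largest eigenvalue of A^T A.
    L = jnp.linalg.eigvalsh(AtA)[-1]

    def step(x, _):
        grad = AtA @ x - Atb  # gradient of 0.5 * ||A x - b||^2
        return jnp.maximum(x - grad / L, 0.0), None  # project onto x >= 0

    x, _ = jax.lax.scan(step, jnp.zeros(A.shape[1]), None, length=iters)
    return x


A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([5.0, 6.0])
x = projected_gradient_nnls(A, b)
# Should agree with the optax.nnls example: approximately [0.00, 1.70].
```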
Second Order Optimization#
fisher_diag | Computes the diagonal of the (observed) Fisher information matrix.
hessian_diag | Computes the diagonal Hessian of loss at (inputs, targets).
hvp | Performs an efficient vector-Hessian (of loss) product.
Fisher diagonal#
- optax.second_order.fisher_diag(negative_log_likelihood: LossFn, params: Any, inputs: Array, targets: Array) -> Array
Computes the diagonal of the (observed) Fisher information matrix.
- Parameters:
negative_log_likelihood – the negative log likelihood function with expected signature loss = fn(params, inputs, targets).
params – model parameters.
inputs – inputs at which negative_log_likelihood is evaluated.
targets – targets at which negative_log_likelihood is evaluated.
- Returns:
An Array containing the diagonal of the (observed) Fisher information matrix of negative_log_likelihood evaluated at (params, inputs, targets).
Hessian diagonal#
- optax.second_order.hessian_diag(loss: LossFn, params: Any, inputs: Array, targets: Array) -> Array
Computes the diagonal Hessian of loss at (inputs, targets).
- Parameters:
loss – the loss function.
params – model parameters.
inputs – inputs at which loss is evaluated.
targets – targets at which loss is evaluated.
- Returns:
An Array containing the diagonal of the Hessian of loss evaluated at (params, inputs, targets).
Hessian vector product#
- optax.second_order.hvp(loss: LossFn, v: Array, params: Any, inputs: Array, targets: Array) -> Array
Performs an efficient vector-Hessian (of loss) product.
- Parameters:
loss – the loss function.
v – a vector of size ravel(params).
params – model parameters.
inputs – inputs at which loss is evaluated.
targets – targets at which loss is evaluated.
- Returns:
An Array corresponding to the product of v and the Hessian of loss evaluated at (params, inputs, targets).
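The following sketch shows one common way to compute a Hessian-vector product in plain JAX: forward-over-reverse differentiation, which never materializes the Hessian. It also recovers the Hessian diagonal from products with basis vectors. The sketch assumes flat array parameters and a hypothetical least-squares loss. It is not the optax.second_order implementation, which works with ravelled parameters as described above.

```python
import jax
import jax.numpy as jnp


def loss(params, inputs, targets):
    # Hypothetical least-squares loss; its Hessian is inputs.T @ inputs.
    preds = inputs @ params
    return 0.5 * jnp.sum((preds - targets) ** 2)


def hvp_sketch(loss, v, params, inputs, targets):
    # Forward-over-reverse: differentiate the gradient along direction v.
    grad_fn = lambda p: jax.grad(loss)(p, inputs, targets)
    return jax.jvp(grad_fn, (params,), (v,))[1]


inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.array([1.0, 0.0, 1.0])
params = jnp.zeros(2)
v = jnp.array([1.0, -1.0])

hv = hvp_sketch(loss, v, params, inputs, targets)  # [[35, 44], [44, 56]] @ v = [-9, -12]
hessian = jax.hessian(loss)(params, inputs, targets)  # explicit Hessian, for checking

# Diagonal entry i is the i-th component of the HVP with the i-th basis vector.
hess_diag = jnp.array(
    [hvp_sketch(loss, e, params, inputs, targets)[i] for i, e in enumerate(jnp.eye(2))]
)
```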
Tree#
NamedTupleKey | KeyType for a NamedTuple in a tree.
tree_add | Add two (or more) pytrees.
tree_add_scale | Add two trees, where the second tree is scaled by a scalar.
tree_allclose | Check whether two trees are element-wise approximately equal within a tolerance.
tree_batch_shape | Add leading batch dimensions to each leaf of a pytree.
tree_cast | Cast tree to given dtype, skip if None.
tree_cast_like | Cast tree to dtypes of other_tree.
tree_clip | Creates an identical tree where all tensors are clipped to [min, max].
tree_conj | Compute the conjugate of a pytree.
tree_div | Divide two pytrees.
tree_dtype | Fetch dtype of tree.
tree_full_like | Creates an identical tree where all tensors are filled with fill_value.
tree_get | Extract a value from a pytree matching a given key.
tree_get_all_with_path | Extract values of a pytree matching a given key.
tree_norm | Compute the vector norm of the given ord of a pytree.
tree_map_params | Apply a callable over all params in the given optimizer state.
tree_max | Compute the max of all the elements in a pytree.
tree_min | Compute the min of all the elements in a pytree.
tree_mul | Multiply two pytrees.
tree_ones_like | Creates an all-ones tree with the same structure.
tree_random_like | Create tree with random entries of the same shape as target tree.
tree_real | Compute the real part of a pytree.
tree_split_key_like | Split keys to match structure of target tree.
tree_scale | Multiply a tree by a scalar.
tree_set | Creates a copy of tree with some values replaced as specified by kwargs.
tree_size | Total size of a pytree.
tree_sub | Subtract two pytrees.
tree_sum | Compute the sum of all the elements in a pytree.
tree_vdot | Compute the inner product between two pytrees.
tree_where | Select tree_x values if condition is true else tree_y values.
tree_zeros_like | Creates an all-zeros tree with the same structure.
NamedTupleKey#
- class optax.tree_utils.NamedTupleKey(tuple_name: str, name: str)
KeyType for a NamedTuple in a tree.
You can pass a function filtering(path: KeyPath, value: Any) -> bool: ... to optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set(). That filter can check whether an entry of the path is a NamedTupleKey, and then whether the name of the named tuple is the one intended to be searched.
- tuple_name
name of the tuple containing the key.
- Type:
str
- name
name of the key.
- Type:
str
See also
jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), optax.tree_utils.tree_set()
Added in version 0.2.2.
Tree add#
- optax.tree_utils.tree_add(tree_x: Any, tree_y: Any, *other_trees: Any) -> Any
Add two (or more) pytrees.
- Parameters:
tree_x – first pytree.
tree_y – second pytree.
*other_trees – optional other trees to add
- Returns:
the sum of the two (or more) pytrees.
Changed in version 0.2.1: Added optional *other_trees argument.
Tree add and scalar multiply#
- optax.tree_utils.tree_add_scale(tree_x: Any, scalar: jax.typing.ArrayLike, tree_y: Any) -> Any
Add two trees, where the second tree is scaled by a scalar.
In infix notation, the function performs out = tree_x + scalar * tree_y.
- Parameters:
tree_x – first pytree.
scalar – scalar value.
tree_y – second pytree.
- Returns:
a pytree with the same structure as tree_x and tree_y.
Tree all close#
- optax.tree_utils.tree_allclose(a: Any, b: Any, rtol: jax.typing.ArrayLike = 1e-05, atol: jax.typing.ArrayLike = 1e-08, equal_nan: bool = False)
Check whether two trees are element-wise approximately equal within a tolerance.
See jax.numpy.allclose() for the equivalent on arrays.
- Parameters:
a – a tree
b – a tree
rtol – relative tolerance used for approximate equality
atol – absolute tolerance used for approximate equality
equal_nan – boolean indicating whether NaNs are treated as equal
- Returns:
a boolean value.
Tree batch reshaping#
Tree cast#
- optax.tree_utils.tree_cast(tree: optax.ArrayTree, dtype: str | type[Any] | dtype | SupportsDType | None) -> optax.ArrayTree
Cast tree to given dtype, skip if None.
- Parameters:
tree – the tree to cast.
dtype – the dtype to cast to, or None to skip.
- Returns:
the tree, with leaves cast to dtype.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16)
{'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}
Tree cast like#
- optax.tree_utils.tree_cast_like(tree: T, other_tree: optax.ArrayTree) -> T
Cast tree to dtypes of other_tree.
- Parameters:
tree – the tree to cast.
other_tree – reference array tree whose leaf dtypes are used for the cast.
- Returns:
the tree, with each leaf cast to the dtype of the corresponding leaf of other_tree.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> other_tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...               'c': jnp.array(2.0, dtype=jnp.bfloat16)}
>>> optax.tree_utils.tree_cast_like(tree, other_tree)
{'a': {'b': Array(1., dtype=float32)}, 'c': Array(2, dtype=bfloat16)}
Tree clip#
- optax.tree_utils.tree_clip(tree: Any, min_value: jax.typing.ArrayLike | None = None, max_value: jax.typing.ArrayLike | None = None) -> Any
Creates an identical tree where all tensors are clipped to [min, max].
- Parameters:
tree – pytree.
min_value – optional minimal value to clip all tensors to. If None (default), the result is not clipped to any minimum value.
max_value – optional maximal value to clip all tensors to. If None (default), the result is not clipped to any maximum value.
- Returns:
a tree with the same structure as tree.
Added in version 0.2.3.
Tree conjugate#
Tree data type#
- optax.tree_utils.tree_dtype(tree: optax.ArrayTree, mixed_dtype_handler: str | None = None) -> str | type[Any] | dtype | SupportsDType
Fetch dtype of tree.
If the tree is empty, returns the default dtype of JAX arrays.
- Parameters:
tree – the tree to fetch the dtype of.
mixed_dtype_handler – how to handle mixed dtypes in the tree.
- If mixed_dtype_handler=None, returns the common dtype of the leaves of the tree if it exists, otherwise raises an error.
- If mixed_dtype_handler='promote', promotes the dtypes of the leaves of the tree to a common promoted dtype using jax.numpy.promote_types().
- If mixed_dtype_handler='highest' or mixed_dtype_handler='lowest', returns the highest/lowest dtype of the leaves of the tree. We consider a partial ordering of dtypes where dtype1 <= dtype2 if dtype1 is promoted to dtype2, that is, if jax.numpy.promote_types(dtype1, dtype2) == dtype2. Since some dtypes cannot be promoted to one another, this is not a total ordering, and the 'highest' or 'lowest' options may not be applicable. These options throw an error if the dtypes of the leaves of the tree cannot be promoted to one another.
- Returns:
the dtype of the tree.
- Raises:
ValueError – If mixed_dtype_handler is set to None and multiple dtypes are found in the tree.
ValueError – If mixed_dtype_handler is set to 'highest' or 'lowest' and some leaves' dtypes in the tree cannot be promoted to one another.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree)
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree, 'lowest')
dtype('float16')
>>> optax.tree_utils.tree_dtype(tree, 'highest')
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)},
...         'c': jnp.array(2.0, dtype=jnp.uint32)}
>>> # optax.tree_utils.tree_dtype(tree, 'highest')
>>> # -> will throw an error because int32 and uint32
>>> # cannot be promoted to one another.
>>> optax.tree_utils.tree_dtype(tree, 'promote')
dtype('int32')
Added in version 0.2.4.
Tree full like#
- optax.tree_utils.tree_full_like(tree: Any, fill_value: jax.typing.ArrayLike, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Any
Creates an identical tree where all tensors are filled with fill_value.
- Parameters:
tree – pytree.
fill_value – the fill value for all tensors in the tree.
dtype – optional dtype to use for the tensors in the tree.
- Returns:
a tree with the same structure as tree.
Tree divide#
Fetch a single value that matches a given key#
- optax.tree_utils.tree_get(tree: optax.PyTree, key: Any, default: Any | None = None, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None) -> Any
Extract a value from a pytree matching a given key.
Search in the tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).
If the tree does not contain key, returns default.
Raises a KeyError if multiple values of key are found in tree.
Generally, you may first get all pairs (path_to_value, value) for a given key using optax.tree_utils.tree_get_all_with_path(). You may then define a filtering operation filtering(path: Key_Path, value: Any) -> bool: ... that selects the specific values you want to fetch by looking at the type of the value or at the path to that value. The paths returned by jax.tree_util.tree_leaves_with_path() omit the names of named tuples. The paths analyzed by the filtering operation in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set() include them. Concretely, if the value considered is in the attribute key of a named tuple called MyNamedTuple, the last element of the path is an optax.tree_utils.NamedTupleKey containing both name=key and tuple_name='MyNamedTuple'. That way you can distinguish between identical values in different named tuples, which arise, for example, when chaining transformations in optax. See the last example below.
- Parameters:
tree – tree to search in.
key – keyword or field to search in tree for.
default – default value to return if key is not found in tree.
filtering – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.
- Returns:
- value
value in tree matching the given key. If none are found, returns the default value. If multiple are found, raises an error.
- Raises:
KeyError – If multiple values of key are found in tree.
Examples
Basic usage
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> count = optax.tree_utils.tree_get(state, 'count')
>>> print(count)
0
Usage with a filtering operation
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> lr = optax.tree_utils.tree_get(
...     state, 'learning_rate', filtering=filtering
... )
>>> print(lr)
1.0
Extracting a named tuple by its name
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, key=0),
...     optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array((), dtype=key<fry>) overlaying:
[0 0])
Differentiating between two values by the name of their named tuples.
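>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, key=0),
...     optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> filtering = (
...     lambda path, value: isinstance(path[-1], optax.tree_utils.NamedTupleKey)
...     and path[-1].tuple_name == 'ScaleByAdamState'
... )
>>> count = optax.tree_utils.tree_get(state, 'count', filtering=filtering)
>>> print(count)
0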
Added in version 0.2.2.
Fetch all values that match a given key#
- optax.tree_utils.tree_get_all_with_path(tree: optax.PyTree, key: Any, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None) -> list[tuple[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any]]
Extract values of a pytree matching a given key.
Search in a pytree tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).
That key/field key may appear more than once in tree, so this function returns a list of all values corresponding to key, each paired with the path to that value. The path is a sequence of KeyEntry that can be transformed into a readable format using jax.tree_util.keystr(); see the example below.
- Parameters:
tree – tree to search in.
key – keyword or field to search in tree for.
filtering – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.
- Returns:
- values_with_path
list of tuples where each tuple is of the form (path_to_value, value). Here value is one entry of the tree that corresponds to the key. path_to_value is a tuple of KeyEntry, that is, a tuple of jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, or optax.tree_utils.NamedTupleKey.
Examples
Basic usage
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...     state, 'learning_rate'
... )
>>> print(
...     *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
...     sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
Usage with a filtering operation
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> filtering = lambda path, value: isinstance(value, tuple)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...     state, 'learning_rate', filtering
... )
>>> print(
...     *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
...     sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
Added in version 0.2.2.
Tree norm#
- optax.tree_utils.tree_norm(tree: Any, ord: int | str | float | None = None, squared: bool = False) -> Array
Compute the vector norm of the given ord of a pytree.
- Parameters:
tree – pytree.
ord – the order of the vector norm to compute (None, 1, 2, inf).
squared – whether the norm should be returned squared or not.
- Returns:
a scalar value.
Tree map parameters#
- optax.tree_utils.tree_map_params(initable: Callable[[optax.Params], optax.OptState] | Initable, f: Callable[..., Any], state: optax.OptState, /, *rest: Any, transform_non_params: Callable[..., Any] | None = None, is_leaf: Callable[[optax.Params], bool] | None = None) -> optax.OptState
Apply a callable over all params in the given optimizer state.
This function exists to help construct partition specs over optimizer states, in the case that a partition spec is already known for the parameters.
For example, the following replaces all optimizer state parameter trees with copies of the given partition spec. The argument transform_non_params can be used to replace any remaining fields as required; here, those fields are replaced with None.
>>> import jax.numpy as jnp
>>> import optax
>>> params, specs = jnp.array(0.), jnp.array(0.)  # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>> opt_specs = optax.tree_map_params(
...     opt,
...     lambda _, spec: spec,
...     state,
...     specs,
...     transform_non_params=lambda _: None,
... )
- Parameters:
initable – A callable taking parameters and returning an optimizer state, or an object with an init attribute having the same function.
f – A callable that will be applied for all copies of the parameter tree within this optimizer state.
state – The optimizer state to map over.
*rest – Additional arguments, having the same shape as the parameter tree, that will be passed to f.
transform_non_params – An optional function that will be called on all non-parameter fields within the optimizer state.
is_leaf – Passed through to jax.tree.map. This makes it possible to ignore parts of the parameter tree, e.g. when the gradient transformations modify the shape of the original pytree, such as for optax.masked.
- Returns:
The result of applying the function f on all trees in the optimizer’s state that have the same shape as the parameter tree, along with the given optional extra arguments.
Tree max#
Tree min#
Tree multiply#
Tree ones like#
- optax.tree_utils.tree_ones_like(tree: Any, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Any
Creates an all-ones tree with the same structure.
- Parameters:
tree – pytree.
dtype – optional dtype to use for the tree of ones.
- Returns:
an all-ones tree with the same structure as tree.
Split key according to structure of a tree#
Tree with random values#
- optax.tree_utils.tree_random_like(rng_key: base.PRNGKey, target_tree: base.ArrayTree, sampler: Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike], jax.typing.ArrayLike] | Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike, jax.sharding.Sharding], jax.typing.ArrayLike] = jax.random.normal, dtype: jax.typing.DTypeLike | None = None) -> base.ArrayTree
Create tree with random entries of the same shape as target tree.
- Parameters:
rng_key – the key for the random number generator.
target_tree – the tree whose structure to match. Leaves must be arrays.
sampler – the noise sampling function, by default jax.random.normal.
dtype – the desired dtype for the random numbers, passed to sampler. If None, the dtype of the target tree is used if possible.
- Returns:
a random tree with the same structure as target_tree, whose leaves are drawn from sampler.
Warning
The possible dtypes may be limited by the sampler. For example, jax.random.rademacher only supports integer dtypes. It raises an error if the dtype argument, or the dtype of the target tree when dtype is None, is not an integer type.
Added in version 0.2.1.
Tree real part#
Tree scalar multiply#
Set values in a tree#
- optax.tree_utils.tree_set(tree: optax.PyTree, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None, /, **kwargs: Any) -> optax.PyTree
Creates a copy of tree with some values replaced as specified by kwargs.
Search in the tree for keys in **kwargs (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple). If such a key is found, replace the corresponding value with the one given in **kwargs.
Raises a KeyError if some keys in **kwargs are not present in the tree.
- Parameters:
tree – pytree whose values are to be replaced.
filtering – optional callable to further filter values in tree that match the keys to replace. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches a given key.
**kwargs – dictionary of keys with values to replace in tree.
- Returns:
- new_tree
new pytree with the same structure as tree. For each element in tree whose key/field matches a key in **kwargs, its value is set by the corresponding value in **kwargs.
- Raises:
KeyError – If no values of some key in **kwargs are found in tree or none of the values satisfy the filtering operation.
Examples
Basic usage
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
Usage with a filtering operation
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...     state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
Note
The recommended way to inject hyperparameter schedules is optax.inject_hyperparams(). This function is a helper for other purposes.
Added in version 0.2.2.
Tree size#
Tree subtract#
Tree sum#
- optax.tree_utils.tree_sum(tree: Any, associative_reduction: bool = False) -> jax.typing.ArrayLike
Compute the sum of all the elements in a pytree.
- Parameters:
tree – pytree.
associative_reduction – If True, use reduce_associative for a potential compilation time speedup with large pytrees (requires JAX >= 0.6.0). This changes the order of summation which may result in slightly different floating-point values. Default is False.
- Returns:
a scalar value.
Tree inner product#
- optax.tree_utils.tree_vdot(tree_x: Any, tree_y: Any) -> jax.typing.ArrayLike
Compute the inner product between two pytrees.
- Parameters:
tree_x – first pytree to use.
tree_y – second pytree to use.
- Returns:
inner product between tree_x and tree_y, a scalar value.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> optax.tree_utils.tree_vdot(
...     {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...     {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)
Note
We upcast the values to the highest precision to avoid numerical issues.
Tree where#
- optax.tree_utils.tree_where(condition, tree_x, tree_y)
Select tree_x values if condition is true else tree_y values.
- Parameters:
condition – boolean specifying which values to select from tree_x or tree_y.
tree_x – pytree chosen if condition is True
tree_y – pytree chosen if condition is False
- Returns:
tree_x or tree_y depending on condition.
Tree zeros like#
- optax.tree_utils.tree_zeros_like(tree: Any, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Any
Creates an all-zeros tree with the same structure.
- Parameters:
tree – pytree.
dtype – optional dtype to use for the tree of zeros.
- Returns:
an all-zeros tree with the same structure as tree.