Utilities#
General#
| optax.scale_gradient | Scales gradients for the backwards pass. |
| optax.value_and_grad_from_state | Alternative to jax.value_and_grad that fetches value, grad from state. |
Scale gradient#
- optax.scale_gradient(inputs, scale)[source]#
Scales gradients for the backwards pass.
- Parameters:
inputs (chex.ArrayTree) – A nested array.
scale (float) – The scale factor for the gradient on the backwards pass.
- Return type:
chex.ArrayTree
- Returns:
An array of the same structure as inputs, with scaled backward gradient.
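A minimal sketch of the behavior (not part of the original docstring): the forward value is unchanged, while the backward gradient is multiplied by scale. Here the gradient of sum(x ** 2) is 2 * x, so scaling by 0.5 gives x back.

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> def fn(x):
...   x = optax.scale_gradient(x, 0.5)  # halve the gradient on the backward pass
...   return jnp.sum(x ** 2)
>>> print(fn(jnp.array([1., 2.])))  # forward pass unchanged: 5.0
>>> print(jax.grad(fn)(jnp.array([1., 2.])))  # 0.5 * 2 * x = [1. 2.]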
Value and grad from state#
- optax.value_and_grad_from_state(value_fn)[source]#
Alternative to jax.value_and_grad that fetches the value and gradient from the state.

Line-search methods such as optax.scale_by_backtracking_linesearch() need to compute the objective value and the gradient at the candidate iterate. This utility lets the next iteration reuse that objective value and gradient instead of recomputing them.

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02
- Parameters:
value_fn (Callable[…, Union[jax.Array, float]]) – function returning a scalar (float or array of dimension 1), amenable to differentiation in jax using jax.value_and_grad().
- Return type:
Callable[…, tuple[Union[float, jax.Array], optax.Updates]]
- Returns:
A callable akin to jax.value_and_grad() that fetches value and grad from the state if present. If no value or grad is found, or if multiple values and grads are found, this function raises an error. If a value is found but is infinite or NaN, the value and grad are recomputed using jax.value_and_grad(). If the gradient found in the state is None, an error is raised.
Numerical Stability#
| optax.safe_int32_increment | Increments int32 counter by one. |
| optax.safe_norm | Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients. |
| optax.safe_root_mean_squares | Returns maximum(sqrt(mean(abs_sq(x))), min_norm) with correct grads. |
Safe int32 increment#
Safe norm#
- optax.safe_norm(x, min_norm, ord=None, axis=None, keepdims=False)[source]#
Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
The gradient of jnp.maximum(jnp.linalg.norm(x), min_norm) at 0.0 is NaN, because jax evaluates both branches of the jnp.maximum. This function instead returns the correct gradient of 0.0 in that setting as well.
- Parameters:
min_norm (Union[Array, ndarray, bool_, number, float, int]) – lower bound for the returned norm.
ord (Union[int, float, str, None]) – {non-zero int, inf, -inf, 'fro', 'nuc'}, optional. Order of the norm. inf means numpy's inf object. The default is None.
axis (Union[None, tuple[int, ...], int]) – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.
keepdims (bool) – bool, optional. If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.
- Return type:
- Returns:
The safe norm of the input vector, accounting for correct gradient.
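A minimal sketch of the behavior described above (not from the original docs): at an all-zeros input the naive expression produces NaN gradients, while safe_norm yields the zero gradient.

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> x = jnp.zeros(3)
>>> naive = lambda x: jnp.maximum(jnp.linalg.norm(x), 1e-6)
>>> print(jax.grad(naive)(x))  # [nan nan nan]
>>> print(jax.grad(lambda x: optax.safe_norm(x, 1e-6))(x))  # [0. 0. 0.]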
Safe root mean squares#
- optax.safe_root_mean_squares(x, min_rms)[source]#
Returns maximum(sqrt(mean(abs_sq(x))), min_norm) with correct grads.
The gradient of maximum(sqrt(mean(abs_sq(x))), min_norm) at 0.0 is NaN, because jax evaluates both branches of the jnp.maximum. This function instead returns the correct gradient of 0.0 in that setting as well.
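Analogously to safe_norm, a brief sketch (not from the original docs) of the gradient at the origin:

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> grad = jax.grad(lambda x: optax.safe_root_mean_squares(x, 1e-6))(jnp.zeros(4))
>>> print(grad)  # finite zero gradient at the origin, instead of NaN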
Linear Algebra Operators#
| optax.matrix_inverse_pth_root | Computes matrix^(-1/p), where p is a positive integer. |
| optax.multi_normal | |
| optax.power_iteration | Power iteration algorithm. |
Multi normal#
Matrix inverse pth root#
- optax.matrix_inverse_pth_root(matrix, p, num_iters=100, ridge_epsilon=1e-06, error_tolerance=1e-06, precision=Precision.HIGHEST)[source]#
Computes matrix^(-1/p), where p is a positive integer.
This function uses the coupled Newton iterations algorithm to compute a matrix's inverse pth root.
References
- Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM, p. 184, Eq. 7.18. https://epubs.siam.org/doi/book/10.1137/1.9780898717778
- Parameters:
matrix (Union[Array, ndarray, bool_, number]) – the symmetric PSD matrix whose power is to be computed.
p (int) – exponent, for p a positive integer.
num_iters (int) – Maximum number of iterations.
ridge_epsilon (float) – Ridge epsilon added to make the matrix positive definite.
error_tolerance (float) – Error indicator, useful for early termination.
precision (Precision) – precision XLA related flag, the available options are: a) lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).
- Returns:
matrix^(-1/p)
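To make the returned quantity concrete, the sketch below forms matrix^(-1/p) directly from an eigendecomposition of a small symmetric PSD matrix; it only illustrates the target quantity and does not call the coupled Newton iterations used by this function.

>>> import jax.numpy as jnp
>>> a = jnp.array([[4.0, 0.0], [0.0, 9.0]])
>>> p = 2
>>> eigvals, eigvecs = jnp.linalg.eigh(a)
>>> reference = (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T
>>> print(reference)  # diag(1/2, 1/3): the inverse square root of a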
Power iteration#
- optax.power_iteration(matrix, *, v0=None, num_iters=100, error_tolerance=1e-06, precision=Precision.HIGHEST, key=None)[source]#
Power iteration algorithm.
This algorithm computes the dominant eigenvalue and its associated eigenvector of a diagonalizable matrix. This matrix can be given as an array or as a callable implementing a matrix-vector product.
References
Wikipedia contributors. Power iteration.
- Parameters:
matrix (Union[chex.Array, Callable[[chex.ArrayTree], chex.ArrayTree]]) – a square matrix, either as an array or a callable implementing a matrix-vector product.
v0 (Optional[chex.ArrayTree]) – initial vector approximating the dominant eigenvector. If matrix is an array of size (n, n), v0 must be a vector of size (n,). If instead matrix is a callable, then v0 must be a tree with the same structure as the input of this callable. If this argument is None and matrix is an array, then a random vector sampled from a uniform distribution in [-1, 1] is used as the initial vector.
num_iters (int) – Number of power iterations.
error_tolerance (float) – Iterative exit condition. The procedure stops when the relative error of the estimate of the dominant eigenvalue is below this threshold.
precision (lax.Precision) – precision XLA related flag, the available options are: a) lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).
key (Optional[chex.PRNGKey]) – random key for the initialization of v0 when not given explicitly. When this argument is None, jax.random.PRNGKey(0) is used.
- Return type:
tuple[chex.Numeric, chex.ArrayTree]
- Returns:
A pair (eigenvalue, eigenvector), where eigenvalue is the dominant eigenvalue of matrix and eigenvector is its associated eigenvector.

Changed in version 0.2.2: matrix can be a callable. Reversed the order of the return parameters, from (eigenvector, eigenvalue) to (eigenvalue, eigenvector).
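A small usage sketch (not from the original docs); the dominant eigenvalue of the diagonal matrix below is 2, so the estimate should converge to it:

>>> import jax.numpy as jnp
>>> import optax
>>> mat = jnp.array([[2.0, 0.0], [0.0, 1.0]])
>>> eigval, eigvec = optax.power_iteration(mat)
>>> print(eigval)  # approximately 2.0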
Second Order Optimization#
| optax.second_order.fisher_diag | Computes the diagonal of the (observed) Fisher information matrix. |
| optax.second_order.hessian_diag | Computes the diagonal hessian of loss at (inputs, targets). |
| optax.second_order.hvp | Performs an efficient vector-Hessian (of loss) product. |
Fisher diagonal#
- optax.second_order.fisher_diag(negative_log_likelihood, params, inputs, targets)[source]#
Computes the diagonal of the (observed) Fisher information matrix.
- Parameters:
negative_log_likelihood (LossFn) – the negative log likelihood function with expected signature loss = fn(params, inputs, targets).
params (Any) – model parameters.
inputs (Array) – inputs at which negative_log_likelihood is evaluated.
targets (Array) – targets at which negative_log_likelihood is evaluated.
- Return type:
- Returns:
An Array corresponding to the diagonal of the (observed) Fisher information matrix of negative_log_likelihood, evaluated at (params, inputs, targets).
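A hedged sketch (not from the original docs) with a hypothetical least-squares negative log likelihood; it assumes the returned diagonal is a flat array with one squared-gradient entry per parameter:

>>> import jax.numpy as jnp
>>> import optax
>>> def nll(params, inputs, targets):  # hypothetical loss = fn(params, inputs, targets)
...   return 0.5 * jnp.sum((inputs @ params - targets) ** 2)
>>> params = jnp.array([1.0, 2.0])
>>> inputs = jnp.eye(2)
>>> targets = jnp.zeros(2)
>>> diag = optax.second_order.fisher_diag(nll, params, inputs, targets)
>>> print(diag)  # elementwise squared gradient of nll, here [1. 4.]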
Hessian diagonal#
Hessian vector product#
Tree#
| optax.tree_utils.NamedTupleKey | KeyType for a NamedTuple in a tree. |
| optax.tree_utils.tree_add | Add two (or more) pytrees. |
| optax.tree_utils.tree_add_scalar_mul | Add two trees, where the second tree is scaled by a scalar. |
| optax.tree_utils.tree_div | Divide two pytrees. |
| optax.tree_utils.tree_get | Extract a value from a pytree matching a given key. |
| optax.tree_utils.tree_get_all_with_path | Extract values of a pytree matching a given key. |
| optax.tree_utils.tree_l1_norm | Compute the l1 norm of a pytree. |
| optax.tree_utils.tree_l2_norm | Compute the l2 norm of a pytree. |
| optax.tree_utils.tree_map_params | Apply a callable over all params in the given optimizer state. |
| optax.tree_utils.tree_mul | Multiply two pytrees. |
| optax.tree_utils.tree_ones_like | Creates an all-ones tree with the same structure. |
| optax.tree_utils.tree_random_like | Create tree with random entries of the same shape as target tree. |
| optax.tree_utils.tree_scalar_mul | Multiply a tree by a scalar. |
| optax.tree_utils.tree_set | Creates a copy of tree with some values replaced as specified by kwargs. |
| optax.tree_utils.tree_sub | Subtract two pytrees. |
| optax.tree_utils.tree_sum | Compute the sum of all the elements in a pytree. |
| optax.tree_utils.tree_vdot | Compute the inner product between two pytrees. |
| optax.tree_utils.tree_zeros_like | Creates an all-zeros tree with the same structure. |
NamedTupleKey#
- class optax.tree_utils.NamedTupleKey(tuple_name, name)[source]#
KeyType for a NamedTuple in a tree.
When using a filtering function filtering(path: KeyPath, value: Any) -> bool: ... in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set(), you can inspect the path to check whether one of its KeyEntries is a NamedTupleKey, and then check whether the name of that named tuple is the one you intend to search for.

See also
jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), optax.tree_utils.tree_set()
- tuple_name#
name of the tuple containing the key.
- Type:
str
- name#
name of the key.
- Type:
str
Added in version 0.2.2.
Tree add#
- optax.tree_utils.tree_add(tree_x, tree_y, *other_trees)[source]#
Add two (or more) pytrees.
- Parameters:
tree_x (Any) – first pytree.
tree_y (Any) – second pytree.
*other_trees (Any) – optional other trees to add.
- Return type:
Any
- Returns:
the sum of the two (or more) pytrees.
Changed in version 0.2.1: Added optional *other_trees argument.
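A small usage sketch (not part of the original docs), adding matching leaves of two dictionaries:

>>> import jax.numpy as jnp
>>> import optax
>>> tree_x = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array(3.0)}
>>> tree_y = {'a': jnp.array([10.0, 20.0]), 'b': jnp.array(30.0)}
>>> summed = optax.tree_utils.tree_add(tree_x, tree_y)
>>> print(summed['a'], summed['b'])  # [11. 22.] 33.0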
Tree add and scalar multiply#
- optax.tree_utils.tree_add_scalar_mul(tree_x, scalar, tree_y)[source]#
Add two trees, where the second tree is scaled by a scalar.
In infix notation, the function performs out = tree_x + scalar * tree_y.
- Parameters:
tree_x (Any) – first pytree.
scalar (Union[float, Array]) – scalar value.
tree_y (Any) – second pytree.
- Return type:
Any
- Returns:
a pytree with the same structure as tree_x and tree_y.
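For instance (a sketch, not from the original docs), out = tree_x + 0.5 * tree_y leaf by leaf:

>>> import jax.numpy as jnp
>>> import optax
>>> tree_x = {'a': jnp.array([1.0, 2.0])}
>>> tree_y = {'a': jnp.array([10.0, 10.0])}
>>> out = optax.tree_utils.tree_add_scalar_mul(tree_x, 0.5, tree_y)
>>> print(out['a'])  # [6. 7.]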
Tree divide#
Fetch single value that matches a given key#
- optax.tree_utils.tree_get(tree, key, default=None, filtering=None)[source]#
Extract a value from a pytree matching a given key.
Search in the tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).

If the tree does not contain key, returns default.

Raises a KeyError if multiple values of key are found in tree.

Generally, you may first get all pairs (path_to_value, value) for a given key using optax.tree_utils.tree_get_all_with_path(). You may then define a filtering operation filtering(path: Key_Path, value: Any) -> bool: ... that enables you to select the specific values you wanted to fetch by looking at the type of the value, or looking at the path to that value. Note that, contrary to the paths returned by jax.tree_util.tree_leaves_with_path(), the paths analyzed by the filtering operation in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set() detail the names of the named tuples considered in the path. Concretely, if the value considered is in the attribute key of a named tuple called MyNamedTuple, the last element of the path will be an optax.tree_utils.NamedTupleKey containing both name=key and tuple_name='MyNamedTuple'. That way you may distinguish between identical values in different named tuples (arising for example when chaining transformations in optax). See the last example below.

Examples
Basic usage
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> count = optax.tree_utils.tree_get(state, 'count')
>>> print(count)
0
Usage with a filtering operation
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> lr = optax.tree_utils.tree_get(
...     state, 'learning_rate', filtering=filtering
... )
>>> print(lr)
1.0
Extracting a named tuple by its name
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32))
Differentiating between two values by the name of their named tuples.
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> filtering = (
...     lambda p, v: isinstance(p[-1], optax.tree_utils.NamedTupleKey)
...     and p[-1].tuple_name == 'ScaleByAdamState'
... )
>>> count = optax.tree_utils.tree_get(state, 'count', filtering=filtering)
>>> print(count)
0
- Parameters:
tree (Any) – tree to search in.
key (Any) – keyword or field to search in tree for.
default (Optional[Any]) – default value to return if key is not found in tree.
filtering (Optional[Callable[[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any], bool]]) – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.
- Return type:
Any
- Returns:
- value
value in tree matching the given key. If none is found, returns the default value. If multiple are found, raises an error.
- Raises:
KeyError – If multiple values of key are found in tree.

Added in version 0.2.2.
Fetch all values that match a given key#
- optax.tree_utils.tree_get_all_with_path(tree, key, filtering=None)[source]#
Extract values of a pytree matching a given key.
Search in a pytree tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).

That key/field key may appear more than once in tree, so this function returns a list of all values corresponding to key together with the path to each value. The path is a sequence of KeyEntry objects that can be rendered in readable form using jax.tree_util.keystr(), see the example below.

Examples
Basic usage
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...     state, 'learning_rate'
... )
>>> print(
...     *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
...     sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
Usage with a filtering operation
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> filtering = lambda path, value: isinstance(value, tuple)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...     state, 'learning_rate', filtering
... )
>>> print(
...     *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
...     sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
- Parameters:
tree (Any) – tree to search in.
key (Any) – keyword or field to search in tree for.
filtering (Optional[Callable[[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any], bool]]) – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.
- Return type:
list[tuple[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any]]
- Returns:
- values_with_path
list of tuples where each tuple is of the form (path_to_value, value). Here value is one entry of the tree that corresponds to the key, and path_to_value is a tuple of KeyEntry, that is, a tuple of jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, or optax.tree_utils.NamedTupleKey.
Added in version 0.2.2.
Tree l1 norm#
Tree l2 norm#
Tree map parameters#
- optax.tree_utils.tree_map_params(initable, f, state, /, *rest, transform_non_params=None, is_leaf=None)[source]#
Apply a callable over all params in the given optimizer state.
This function exists to help construct partition specs over optimizer states, in the case that a partition spec is already known for the parameters.
For example, the following will replace all parameter trees in the optimizer state with copies of the given partition spec. The argument transform_non_params can be used to replace any remaining fields as required; in this case, we replace those fields by None.
>>> import jax.numpy as jnp
>>> import optax
>>> params, specs = jnp.array(0.), jnp.array(0.)  # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>> opt_specs = optax.tree_map_params(
...     opt,
...     lambda _, spec: spec,
...     state,
...     specs,
...     transform_non_params=lambda _: None,
... )
- Parameters:
initable (Union[Callable[[optax.Params], optax.OptState], Initable]) – A callable taking parameters and returning an optimizer state, or an object with an init attribute having the same function.
f (Callable[…, Any]) – A callable that will be applied for all copies of the parameter tree within this optimizer state.
state (optax.OptState) – The optimizer state to map over.
*rest (Any) – Additional arguments, having the same shape as the parameter tree, that will be passed to f.
transform_non_params (Optional[Callable[…, Any]]) – An optional function that will be called on all non-parameter fields within the optimizer state.
is_leaf (Optional[Callable[[optax.Params], bool]]) – Passed through to jax.tree_map. This makes it possible to ignore parts of the parameter tree, e.g. when the gradient transformations modify the shape of the original pytree, such as for optax.masked.
- Return type:
optax.OptState
- Returns:
The result of applying the function f on all trees in the optimizer’s state that have the same shape as the parameter tree, along with the given optional extra arguments.
Tree multiply#
Tree ones like#
- optax.tree_utils.tree_ones_like(tree, dtype=None)[source]#
Creates an all-ones tree with the same structure.
- Parameters:
tree (Any) – pytree.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype to use for the tree of ones.
- Return type:
Any
- Returns:
an all-ones tree with the same structure as tree.
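A brief sketch (not from the original docs):

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'w': jnp.zeros((2,)), 'b': jnp.zeros(())}
>>> ones = optax.tree_utils.tree_ones_like(tree, dtype=jnp.float32)
>>> print(ones['w'], ones['b'])  # [1. 1.] 1.0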
Tree with random values#
- optax.tree_utils.tree_random_like(rng_key, target_tree, sampler=<function normal>)[source]#
Create tree with random entries of the same shape as target tree.
- Parameters:
rng_key (chex.PRNGKey) – the key for the random number generator.
target_tree (chex.ArrayTree) – the tree whose structure to match. Leaves must be arrays.
sampler (Callable[[chex.PRNGKey, base.Shape], chex.Array]) – the noise sampling function, by default jax.random.normal.
- Return type:
chex.ArrayTree
- Returns:
a random tree with the same structure as target_tree, whose leaves have distribution sampler.
Added in version 0.2.1.
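A brief sketch (not from the original docs); here the sampler is swapped to jax.random.uniform, which also takes a key and a shape:

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> target = {'w': jnp.zeros((2, 3)), 'b': jnp.zeros((3,))}
>>> noise = optax.tree_utils.tree_random_like(
...     jax.random.PRNGKey(0), target, sampler=jax.random.uniform
... )
>>> print(noise['w'].shape, noise['b'].shape)  # (2, 3) (3,)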
Tree scalar multiply#
Set values in a tree#
- optax.tree_utils.tree_set(tree, filtering=None, /, **kwargs)[source]#
Creates a copy of tree with some values replaced as specified by kwargs.
Search in the tree for the keys given in **kwargs (each can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple). If such a key is found, replace the corresponding value with the one given in **kwargs.

Raises a KeyError if some keys in **kwargs are not present in the tree.

Note

The recommended way to inject hyperparameter schedules is through optax.inject_hyperparams(). This function is a helper for other purposes.

Examples
Basic usage
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
Usage with a filtering operation
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...     state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
- Parameters:
tree (Any) – pytree whose values are to be replaced.
filtering (Optional[Callable[[Tuple[Union[DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, NamedTupleKey], ...], Any], bool]]) – optional callable to further filter values in tree that match the keys to replace. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches a given key.
**kwargs (Any) – dictionary of keys with values to replace in tree.
- Return type:
Any
- Returns:
- new_tree
new pytree with the same structure as tree. For each element in tree whose key/field matches a key in **kwargs, its value is set by the corresponding value in **kwargs.
- Raises:
KeyError – If no values of some key in **kwargs are found in tree or none of the values satisfy the filtering operation.
Added in version 0.2.2.
Tree subtract#
Tree sum#
Tree inner product#
- optax.tree_utils.tree_vdot(tree_x, tree_y)[source]#
Compute the inner product between two pytrees.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> optax.tree_utils.tree_vdot(
...     {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...     {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)
- Parameters:
tree_x (Any) – first pytree to use.
tree_y (Any) – second pytree to use.
- Return type:
- Returns:
inner product between tree_x and tree_y, a scalar value.
Implementation detail: we upcast the values to the highest precision to avoid numerical issues.
Tree zeros like#
- optax.tree_utils.tree_zeros_like(tree, dtype=None)[source]#
Creates an all-zeros tree with the same structure.
- Parameters:
tree (Any) – pytree.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype to use for the tree of zeros.
- Return type:
Any
- Returns:
an all-zeros tree with the same structure as tree.