optax.tree_utils.tree_norm

optax.tree_utils.tree_norm(tree: Any, ord: int | str | float | None = None, squared: bool = False) -> Array

Compute the vector norm of order ord of a pytree, treating all of its leaves as a single flattened vector.
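To make "vector norm of a pytree" concrete, here is a minimal sketch. It assumes every leaf is a numeric array. The norm is computed as if all leaves were flattened and concatenated into one vector. The sketch builds that reference value with jax.flatten_util.ravel_pytree and jnp.linalg.norm. The pytree params is a hypothetical example, not part of the optax API.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree: a dict of arrays with different shapes.
params = {"w": jnp.array([[3.0, 0.0], [0.0, 4.0]]), "b": jnp.array([12.0])}

# Flatten every leaf and concatenate them into one 1-D vector.
flat, _ = ravel_pytree(params)

# Standard vector norms of that single vector.
l2 = jnp.linalg.norm(flat)                   # sqrt(9 + 16 + 144) = 13.0
l1 = jnp.linalg.norm(flat, ord=1)            # 3 + 4 + 12 = 19.0
linf = jnp.linalg.norm(flat, ord=jnp.inf)    # max |x| = 12.0
```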

Parameters:
  • tree – the pytree whose norm is computed.

  • ord – the order of the vector norm to compute (None, 1, 2, or inf). None (the default) gives the Euclidean 2-norm.

  • squared – if True, return the square of the norm instead of the norm itself.

Returns:

  a scalar array containing the (optionally squared) norm.
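The sketch below shows one way the ord and squared options could map onto per-leaf reductions. It assumes every leaf is a numeric array and the tree has at least one leaf. The name tree_norm_sketch is hypothetical and the code is an illustration, not optax's implementation. It supports the orders listed above (None/2, 1, inf). Accepting the string "inf" is an assumption, suggested by the str in the signature's type hint.

```python
import jax
import jax.numpy as jnp


def tree_norm_sketch(tree, ord=None, squared=False):
  """Hypothetical pytree vector norm; an illustration, not optax's source."""
  leaves = jax.tree_util.tree_leaves(tree)  # assumes at least one numeric leaf
  if ord is None or ord == 2:
    # Euclidean norm: sum the squares of every element of every leaf.
    sqnorm = sum(jnp.sum(jnp.square(x)) for x in leaves)
    return sqnorm if squared else jnp.sqrt(sqnorm)
  if ord == 1:
    # 1-norm: sum of absolute values across all leaves.
    norm = sum(jnp.sum(jnp.abs(x)) for x in leaves)
  elif ord == jnp.inf or ord == "inf":
    # Infinity norm: largest absolute value found in any leaf.
    norm = jnp.max(jnp.stack([jnp.max(jnp.abs(x)) for x in leaves]))
  else:
    raise ValueError(f"Unsupported ord: {ord!r}")
  return jnp.square(norm) if squared else norm


# Usage on a small hypothetical pytree.
params = {"a": jnp.array([1.0, -2.0]), "b": jnp.array([[2.0]])}
print(tree_norm_sketch(params))                # 3.0
print(tree_norm_sketch(params, squared=True))  # 9.0
print(tree_norm_sketch(params, ord=1))         # 5.0
print(tree_norm_sketch(params, ord="inf"))     # 2.0
```

The sketch reduces each leaf separately and then combines the partial results. This avoids materializing a concatenated copy of every parameter. For the same inputs, the values would be expected to match a call such as optax.tree_utils.tree_norm(params, ord=1).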