optax.tree_utils.tree_clip

optax.tree_utils.tree_clip(tree: Any, min_value: jax.typing.ArrayLike | None = None, max_value: jax.typing.ArrayLike | None = None) → Any

Creates a tree with the same structure as tree in which every leaf tensor is clipped elementwise to [min_value, max_value].
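The clipping is applied leaf by leaf. The following is a minimal sketch of equivalent behavior in plain JAX, assuming each leaf is clipped elementwise and a bound of None is simply skipped; `tree_clip_sketch` is a hypothetical name used for illustration and is not part of optax.

```python
import jax
import jax.numpy as jnp


def tree_clip_sketch(tree, min_value=None, max_value=None):
    """Hypothetical re-implementation: clip every leaf of `tree` to [min_value, max_value].

    A bound that is None is skipped, so the clipping can be one-sided.
    """
    def clip_leaf(x):
        x = jnp.asarray(x)
        if min_value is not None:
            x = jnp.maximum(x, min_value)  # raise values below the lower bound
        if max_value is not None:
            x = jnp.minimum(x, max_value)  # lower values above the upper bound
        return x

    # tree_map applies clip_leaf to every leaf and rebuilds the same structure.
    return jax.tree_util.tree_map(clip_leaf, tree)


params = {"w": jnp.array([-2.0, 0.5, 3.0]), "b": jnp.array(1.5)}
clipped = tree_clip_sketch(params, min_value=-1.0, max_value=1.0)
```

Here clipped has the same dictionary structure as params, with "w" becoming [-1.0, 0.5, 1.0] and "b" becoming 1.0.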

Parameters:
  • tree – pytree.

  • min_value – optional minimum value to clip all tensors to. If None (the default), the result is not clipped from below.

  • max_value – optional maximum value to clip all tensors to. If None (the default), the result is not clipped from above.

Returns:

a tree with the same structure as tree.

Added in version 0.2.3.