optax.tree_utils.tree_clip
- optax.tree_utils.tree_clip(tree: Any, min_value: jax.typing.ArrayLike | None = None, max_value: jax.typing.ArrayLike | None = None) -> Any
Creates a new tree with the same structure as tree, in which every tensor is clipped elementwise to [min_value, max_value].
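As a rough illustration of the documented behavior, the sketch below maps `jnp.clip` over every leaf with `jax.tree_util.tree_map`. It is an approximation written for this page, not necessarily optax's own implementation; the helper name `clip_tree_sketch` and the example `params` dictionary are hypothetical.

```python
import jax
import jax.numpy as jnp


def clip_tree_sketch(tree, min_value=None, max_value=None):
  """Hypothetical helper approximating tree_clip's documented behavior."""
  # jnp.clip skips a bound that is None, so passing None for min_value or
  # max_value leaves that side of every leaf unclipped.
  return jax.tree_util.tree_map(
      lambda x: jnp.clip(x, min_value, max_value), tree)


params = {"w": jnp.array([-2.0, 0.5, 3.0]), "b": jnp.array(1.5)}

# Clip every leaf to [-1, 1]; the dict structure is preserved.
clipped = clip_tree_sketch(params, min_value=-1.0, max_value=1.0)

# Clip only from below; values above 0 pass through unchanged.
lower_only = clip_tree_sketch(params, min_value=0.0)
```

Because the mapping is leaf-by-leaf, the output always has the same tree structure as the input; only the array values change.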
- Parameters:
tree – pytree to clip.
min_value – optional minimum value to clip all tensors to. If None (default), the result is not clipped to any minimum value.
max_value – optional maximum value to clip all tensors to. If None (default), the result is not clipped to any maximum value.
- Returns:
a tree with the same structure as tree.
Added in version 0.2.3.