optax.tree_utils.tree_where
- optax.tree_utils.tree_where(condition, tree_x, tree_y)
Select the values of tree_x if condition is true, otherwise the values of tree_y.
- Parameters:
condition – boolean specifying which values to select from tree_x or tree_y
tree_x – pytree chosen if condition is True
tree_y – pytree chosen if condition is False
- Returns:
tree_x or tree_y depending on condition.
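The selection can be sketched as a leaf-wise `jnp.where` mapped over both trees. The helpers `tree_where_sketch` and `keep_if_finite` are hypothetical illustrations, assuming `tree_x` and `tree_y` share the same tree structure; they show the expected behavior and are not optax's own source. Because the choice is made with `jnp.where` rather than a Python `if`, the condition can be a traced value inside `jax.jit`, for example a check that a proposed parameter update is finite.

```python
import jax
import jax.numpy as jnp


def tree_where_sketch(condition, tree_x, tree_y):
    """Hypothetical stand-in for tree_where: leaf-wise jnp.where.

    Assumes tree_x and tree_y have identical pytree structure.
    """
    return jax.tree_util.tree_map(
        lambda x, y: jnp.where(condition, x, y), tree_x, tree_y
    )


@jax.jit
def keep_if_finite(old_params, new_params):
    # Hypothetical usage: accept new_params only if every leaf is finite.
    # A Python `if` on this traced boolean would fail under jit;
    # jnp.where inside tree_where_sketch does not.
    leaves = jax.tree_util.tree_leaves(new_params)
    ok = jnp.all(jnp.stack([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))
    return tree_where_sketch(ok, new_params, old_params)


old = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
good = {"w": jnp.array([1.5, 2.5]), "b": jnp.array(0.25)}
bad = {"w": jnp.array([jnp.nan, 2.5]), "b": jnp.array(0.25)}

accepted = keep_if_finite(old, good)  # all leaves finite: returns good
rejected = keep_if_finite(old, bad)   # NaN in "w": falls back to old
```

Note that both trees are fully computed before the selection happens; `jnp.where` chooses between already-evaluated values, so this is a data-dependent select rather than a branch that skips work.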