optax.tree_utils.tree_where

optax.tree_utils.tree_where(condition, tree_x, tree_y)

Select the values of tree_x if condition is true, otherwise the values of tree_y.
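One way to picture this selection is as a `jnp.where` applied to each pair of matching leaves. The sketch below is a hypothetical re-implementation, not optax's source. `tree_where_sketch` is an illustrative name, and it assumes that tree_x and tree_y share the same pytree structure and that condition is a boolean scalar or an array that broadcasts against every leaf.

```python
import jax
import jax.numpy as jnp


def tree_where_sketch(condition, tree_x, tree_y):
    """Hypothetical stand-in for tree_where: pick each leaf from tree_x or tree_y.

    Assumes tree_x and tree_y have identical pytree structure and that
    `condition` broadcasts against every leaf, as jnp.where requires.
    """
    # Walk both trees in lockstep and select per leaf with jnp.where.
    return jax.tree_util.tree_map(
        lambda x, y: jnp.where(condition, x, y), tree_x, tree_y
    )


tree_x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
tree_y = {"w": jnp.array([-1.0, -2.0]), "b": jnp.array(-3.0)}

# A true condition keeps every leaf of tree_x.
picked = tree_where_sketch(jnp.array(True), tree_x, tree_y)
```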

Parameters:
  • condition – boolean specifying which values to select from tree_x or tree_y

  • tree_x – pytree chosen if condition is True

  • tree_y – pytree chosen if condition is False

Returns:

tree_x or tree_y depending on condition.