optax.tree_utils.tree_allclose

optax.tree_utils.tree_allclose(a: Any, b: Any, rtol: jax.typing.ArrayLike = 1e-05, atol: jax.typing.ArrayLike = 1e-08, equal_nan: bool = False)

Check whether two trees are element-wise approximately equal within a tolerance.

See jax.numpy.allclose() for the equivalent on arrays.

Parameters:
  • a – a tree

  • b – a tree

  • rtol – relative tolerance used for approximate equality

  • atol – absolute tolerance used for approximate equality

  • equal_nan – boolean indicating whether NaNs are treated as equal

Returns:

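a boolean value: True if every pair of corresponding leaves is approximately equal within the given tolerances, False otherwise.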
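
Example:

The following is a minimal sketch of the semantics described above, not necessarily how optax implements the function. It assumes the check amounts to applying jax.numpy.allclose() to each pair of corresponding leaves and combining the results with a logical AND. The helper name tree_allclose_sketch and the example trees are hypothetical, introduced only for illustration.

```python
import jax
import jax.numpy as jnp


def tree_allclose_sketch(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Hypothetical helper: leaf-wise jnp.allclose combined with a logical AND."""
    # Compare corresponding leaves; tree_map requires a and b to share a structure.
    leafwise = jax.tree_util.tree_map(
        lambda x, y: jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan),
        a,
        b,
    )
    # Reduce the per-leaf booleans to a single Python bool.
    return all(bool(leaf) for leaf in jax.tree_util.tree_leaves(leafwise))


# Illustrative trees (assumed data, not taken from the optax documentation).
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
nudged = {"w": jnp.array([1.0, 2.0 + 1e-6]), "b": jnp.array(0.5)}
shifted = {"w": jnp.array([1.0, 2.1]), "b": jnp.array(0.5)}
with_nan = {"w": jnp.array([jnp.nan, 2.0]), "b": jnp.array(0.5)}

# A perturbation well inside the default tolerances still counts as close.
close_small = tree_allclose_sketch(params, nudged)
# A difference of 0.1 in one leaf makes the whole comparison fail.
close_large = tree_allclose_sketch(params, shifted)
# NaNs compare unequal unless equal_nan=True is passed.
nan_default = tree_allclose_sketch(with_nan, with_nan)
nan_equal = tree_allclose_sketch(with_nan, with_nan, equal_nan=True)
```

Under the assumption stated above, calling optax.tree_utils.tree_allclose with the same trees and keyword arguments should agree with this sketch. A typical use is asserting in a test that two parameter trees, for example before and after a save/restore round trip, match up to floating-point noise.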