optax.tree_utils.tree_allclose
- optax.tree_utils.tree_allclose(a: Any, b: Any, rtol: jax.typing.ArrayLike = 1e-05, atol: jax.typing.ArrayLike = 1e-08, equal_nan: bool = False)
Check whether two trees are element-wise approximately equal within a tolerance.
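Here is a sketch of what "within a tolerance" means. It assumes each pair of corresponding leaves is compared the way jax.numpy.allclose() compares arrays, as the reference to that function suggests. rtol and atol are the tolerance parameters documented in this entry. For every element i of every pair of corresponding leaves, the following must hold:

$$
\lvert a_i - b_i \rvert \le \mathrm{atol} + \mathrm{rtol} \cdot \lvert b_i \rvert
$$

The result is true only if this holds everywhere. Under that assumption, the bound scales with $\lvert b_i \rvert$ rather than $\lvert a_i \rvert$, so the check is not strictly symmetric in a and b. Near zero the absolute tolerance atol dominates. A NaN element fails the comparison unless equal_nan is set.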
See jax.numpy.allclose() for the equivalent on arrays.
- Parameters:
a – the first tree to compare
b – the second tree to compare
rtol – relative tolerance used for approximate equality
atol – absolute tolerance used for approximate equality
equal_nan – boolean indicating whether NaNs in a and b are treated as equal
- Returns:
a boolean value, true only if all corresponding leaves of a and b are element-wise approximately equal.
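The check can be sketched in plain JAX. The sketch assumes, as the reference to jax.numpy.allclose() suggests, that each pair of corresponding leaves is compared with jnp.allclose and the per-leaf results are combined with a logical AND. tree_allclose_sketch is a hypothetical name used only for illustration and is not part of optax. In real code, call optax.tree_utils.tree_allclose directly.

```python
import functools
import operator

import jax
import jax.numpy as jnp


def tree_allclose_sketch(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Hypothetical stand-in for a tree-wide allclose check (not optax code).

    Assumes each pair of corresponding leaves is compared with jnp.allclose
    and the per-leaf booleans are combined with a logical AND.
    """
    # Compare leaf by leaf; tree_map raises an error if the structures
    # of a and b do not match.
    per_leaf = jax.tree_util.tree_map(
        lambda x, y: jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan),
        a,
        b,
    )
    # Fold the per-leaf results with AND; an empty tree yields True.
    return functools.reduce(operator.and_, jax.tree_util.tree_leaves(per_leaf), True)


# Example: a leaf of zeros compared with the same leaf shifted by 1e-6.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
shifted = {"w": params["w"], "b": params["b"] + 1e-6}
print(tree_allclose_sketch(params, shifted))             # False with atol=1e-08
print(tree_allclose_sketch(params, shifted, atol=1e-5))  # True
```

Near zero the rtol term contributes almost nothing, so atol decides whether a small absolute difference passes. The default 1e-08 rejects the 1e-6 shift above, while atol=1e-5 accepts it.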