optax.tree_utils.tree_batch_shape

optax.tree_utils.tree_batch_shape(tree: Any, shape: tuple[int, ...] = ())

Add leading batch dimensions to each leaf of a pytree.

Parameters:
    tree – a pytree.
    shape – a shape indicating what leading batch dimensions to add.

Returns:
    a pytree with the leading batch dimensions added.
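
The following is a minimal sketch of the behavior described above, not the library's own source. It assumes that "adding leading batch dimensions" means broadcasting each leaf to `shape + leaf.shape`; the helper name `batch_shape_sketch` and the example `params` tree are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def batch_shape_sketch(tree, shape=()):
    """Hypothetical stand-in: prepend `shape` to the shape of every leaf.

    Assumption: leaves are broadcast (values repeated) along the new
    leading dimensions. This is a sketch, not the optax implementation.
    """
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, tuple(shape) + jnp.shape(x)),
        tree,
    )


# Example pytree: a dict holding a weight matrix and a bias vector.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}

# Add one leading batch dimension of size 4 to every leaf.
batched = batch_shape_sketch(params, shape=(4,))
batched_shapes = {name: leaf.shape for name, leaf in batched.items()}

# With the default empty shape, leaf shapes are left unchanged.
unchanged = batch_shape_sketch(params)
unchanged_shapes = {name: leaf.shape for name, leaf in unchanged.items()}
```

Under these assumptions, `batched_shapes` maps `"w"` to `(4, 3, 2)` and `"b"` to `(4, 2)`, while the default `shape=()` adds no dimensions.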