optax.tree_utils.tree_batch_shape


optax.tree_utils.tree_batch_shape(tree: Any, shape: tuple[int, ...] = ())

Add leading batch dimensions to each leaf of a pytree.
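To make the per-leaf transformation concrete, here is a minimal JAX sketch of what "adding leading batch dimensions" means, assuming every leaf is an array with .shape and .dtype. batch_shape_sketch is a hypothetical helper written for illustration, not the optax source. It returns jax.ShapeDtypeStruct leaves (shape and dtype only, no memory allocated). That is a choice of this sketch; this page does not say whether the library returns arrays or abstract shapes.

```python
import jax
import jax.numpy as jnp


def batch_shape_sketch(tree, shape=()):
    """Hypothetical stand-in for tree_batch_shape (not the optax source).

    Each leaf becomes a jax.ShapeDtypeStruct whose shape is `shape` followed
    by the leaf's own shape; the leaf's dtype is carried over unchanged.
    """
    shape = tuple(shape)
    return jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(shape + tuple(x.shape), x.dtype), tree
    )


params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros((2,), dtype=jnp.float16)}
batched = batch_shape_sketch(params, (4,))
print(batched["w"].shape)  # (4, 3, 2)
print(batched["b"].shape, batched["b"].dtype)  # (4, 2) float16
```

The tree structure (here, a dict with keys "w" and "b") is preserved. Only each leaf's shape gains the (4,) prefix.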

Parameters:
  • tree – a pytree.

  • shape – a tuple of ints giving the leading batch dimensions to prepend to every leaf. Defaults to (), which adds no batch dimensions.

Returns:

a pytree with the same structure as tree, in which each leaf's shape is shape followed by that leaf's original shape.