optax.projections.projection_l2_sphere


optax.projections.projection_l2_sphere(tree: Any, scale: jax.typing.ArrayLike = 1) -> Any

Projection onto the l2 sphere.

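This function solves the following constrained optimization problem, where x is the input tree and the l2 norm is computed over all of its leaves jointly, as if they were concatenated into a single vector.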

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_2 = \text{scale}\]
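As a worked note (a derivation added for illustration, not taken from the optax source), the problem has a closed-form solution for a nonzero input and a positive radius. Expanding the objective for any feasible y, so that \|y\|_2 = \text{scale}, gives

\[\|x - y\|_2^2 = \|x\|_2^2 - 2\langle x, y\rangle + \text{scale}^2,\]

so minimizing it amounts to maximizing \langle x, y\rangle. By the Cauchy–Schwarz inequality, \langle x, y\rangle \le \|x\|_2 \, \text{scale}, with equality exactly when y is a nonnegative multiple of x:

\[y^\star = \text{scale} \cdot \frac{x}{\|x\|_2}, \qquad x \neq 0.\]

When x = 0, every point on the sphere is equally close to x, so the minimizer is not unique.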
Parameters:
  • tree – tree to project.

  • scale – radius of the sphere.

Returns:

projected tree, with the same structure as tree.

Added in version 0.2.4.
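The closed-form rescaling can be sketched in JAX as follows. The helper `l2_sphere_projection_sketch` is hypothetical and written only for illustration; it is not optax's own implementation. It assumes the l2 norm is taken jointly over all leaves of the tree and that the input has a nonzero norm.

```python
import jax
import jax.numpy as jnp


def l2_sphere_projection_sketch(tree, scale=1.0):
    """Hypothetical helper: project a pytree onto the l2 sphere of radius `scale`.

    Assumes the l2 norm is taken over all leaves jointly and that the tree
    has a nonzero norm (at x = 0 the projection is not unique).
    """
    leaves = jax.tree_util.tree_leaves(tree)
    # Global l2 norm: square root of the sum of squares over every leaf.
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    factor = scale / norm
    # Rescale every leaf by the same factor, preserving the tree structure.
    return jax.tree_util.tree_map(lambda leaf: factor * leaf, tree)


params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
projected = l2_sphere_projection_sketch(params, scale=2.0)
# The global norm of params is 5, so every leaf is scaled by 2 / 5:
# projected["w"] is approximately [1.2, 0.0] and projected["b"] approximately [1.6].
```

In practice one would call `optax.projections.projection_l2_sphere(params, scale=2.0)` directly; the sketch only makes the rescaling step explicit.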