optax.projections.projection_l2_sphere
- optax.projections.projection_l2_sphere(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto the l2 sphere.
This function solves the following constrained optimization problem, where x is the input tree:

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_2 = \text{scale}\]

- Parameters:
tree – tree to project.
scale – radius of the sphere.
- Returns:
projected tree, with the same structure as tree.
Added in version 0.2.4.
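For a nonzero x, the problem above has a closed-form answer. Every y on the sphere satisfies \(\|x - y\|_2^2 = \|x\|_2^2 - 2\langle x, y\rangle + \text{scale}^2\), so minimizing the distance means maximizing \(\langle x, y\rangle\). By Cauchy-Schwarz the maximizer is \(y = \text{scale} \cdot x / \|x\|_2\). The sketch below re-derives that result with plain JAX pytree utilities. It is not the library's source. The helper name `project_l2_sphere_sketch` is hypothetical, and it assumes the norm is taken jointly over all leaves of the tree and that the tree is not identically zero, since the direction is undefined at the origin.

```python
import jax
import jax.numpy as jnp


def project_l2_sphere_sketch(tree, scale=1.0):
    """Hypothetical sketch: rescale a whole pytree so its joint l2 norm is `scale`.

    Assumes the norm is computed jointly over every leaf and that the tree
    is not all zeros (the projection is undefined at the origin).
    """
    leaves = jax.tree_util.tree_leaves(tree)
    # Joint l2 norm: square every entry of every leaf, sum, then take the root.
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    factor = scale / norm
    # Multiply each leaf by the same factor, so the pytree structure is preserved.
    return jax.tree_util.tree_map(lambda leaf: factor * leaf, tree)


x = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
# The joint norm of x is 5, so every leaf is scaled by 2 / 5.
y = project_l2_sphere_sketch(x, scale=2.0)
```

With optax installed, the documented call for the same input would be `optax.projections.projection_l2_sphere(x, scale=2.0)`.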