optax.projections.projection_simplex
- optax.projections.projection_simplex(tree: Any, scale: jax.typing.ArrayLike = 1) -> Any
Projection onto a simplex.
This function solves the following constrained optimization problem, where
\(x\) is the input tree.
\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \ge 0, ~ p^\top 1 = \text{scale}\]
By default, the projection is onto the probability simplex (unit simplex).
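To make the optimization problem above concrete, here is a minimal pure-Python sketch of one common way to solve it: the sort-and-threshold method. It is an illustration only and is not claimed to be how optax implements the projection. The function name `project_onto_simplex` is hypothetical, it works on a flat list rather than a tree, and it assumes `scale > 0`. The input reuses the leaf values from the example on this page.

```python
from itertools import accumulate


def project_onto_simplex(x, scale=1.0):
    """Euclidean projection of a flat list onto {p : p >= 0, sum(p) == scale}.

    Hypothetical helper for illustration; assumes scale > 0.
    """
    # Sort the entries in decreasing order.
    u = sorted(x, reverse=True)
    # Cumulative sums of the sorted entries, shifted by the target sum.
    shifted = [c - scale for c in accumulate(u)]
    # Largest 1-based index k whose entry stays positive after thresholding.
    # k = 1 always qualifies when scale > 0, so the max is well defined.
    rho = max(
        k for k in range(1, len(u) + 1) if u[k - 1] - shifted[k - 1] / k > 0
    )
    # Threshold subtracted from every entry before clipping at zero.
    tau = shifted[rho - 1] / rho
    return [max(v - tau, 0.0) for v in x]


p = project_onto_simplex([2.5, 3.2, 0.5])
print(p)       # approximately [0.15, 0.85, 0.0]
print(sum(p))  # approximately 1.0
```

Every entry is shifted by the same threshold and then clipped at zero, so the result satisfies both constraints: no entry is negative, and the entries sum to `scale`.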
- Parameters:
tree – tree to project.
scale – value the projected tree should sum to (default: 1).
- Returns:
projected tree, a tree with the same structure as tree.
Example
Here is an example using a tree:
>>> import jax.numpy as jnp
>>> from optax import tree, projections
>>> data = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree.sum(data))
6.2
>>> new_data = projections.projection_simplex(data)
>>> print(tree.sum(new_data))
1.0000002
Added in version 0.2.3.