Projections
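Projections can be used to perform constrained optimization. The Euclidean projection onto a set \(\mathcal{C}\) is:
\[\operatorname{proj}_{\mathcal{C}}(x) := \underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \in \mathcal{C}\]
For instance, here is an example of how we can project parameters onto the non-negative orthant after an optimizer step: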
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)
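Available projections
| Function | Description |
| --- | --- |
| projection_box | Projection onto box constraints. |
| projection_hypercube | Projection onto the (unit) hypercube. |
| projection_l1_ball | Projection onto the l1 ball. |
| projection_l1_sphere | Projection onto the l1 sphere. |
| projection_l2_ball | Projection onto the l2 ball. |
| projection_l2_sphere | Projection onto the l2 sphere. |
| projection_linf_ball | Projection onto the l-infinity ball. |
| projection_non_negative | Projection onto the non-negative orthant. |
| projection_simplex | Projection onto a simplex. |
| projection_vector | Projection onto a vector. |
| projection_hyperplane | Projection onto a hyperplane. |
| projection_halfspace | Projection onto a halfspace. |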
Projection onto a box
- optax.projections.projection_box(tree: Any, lower: Any, upper: Any) → Any
Projection onto box constraints.
\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad \text{lower} \le p \le \text{upper}\]
where \(x\) is the input tree.
- Parameters:
tree – tree to project.
lower – lower bound, a scalar or tree with the same structure as tree.
upper – upper bound, a scalar or tree with the same structure as tree.
- Returns:
projected tree, with the same structure as tree.
Projection onto a hypercube
- optax.projections.projection_hypercube(tree: Any, scale: Any = 1) → Any
Projection onto the (unit) hypercube.
\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad 0 \le p \le \text{scale}\]
where \(x\) is the input tree.
By default, we project to the unit hypercube (scale=1).
This is a convenience wrapper around projection_box.
- Parameters:
tree – tree to project.
scale – scale of the hypercube, a scalar or a tree (default: 1).
- Returns:
projected tree, with the same structure as tree.
Projection onto the L1 ball
- optax.projections.projection_l1_ball(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto the l1 ball.
This function solves the following constrained optimization problem, where \(x\) is the input tree:
\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_1 \le \text{scale}\]
- Parameters:
tree – tree to project.
scale – radius of the ball.
- Returns:
projected tree, with the same structure as tree.
Example
>>> import jax.numpy as jnp
>>> from optax import tree, projections
>>> data = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree.norm(data, ord=1))
6.2
>>> new_data = projections.projection_l1_ball(data)
>>> print(tree.norm(new_data, ord=1))
1.0000002
Added in version 0.2.4.
Projection onto the L1 sphere
- optax.projections.projection_l1_sphere(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto the l1 sphere.
This function solves the following constrained optimization problem, where \(x\) is the input tree:
\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_1 = \text{scale}\]
- Parameters:
tree – tree to project.
scale – radius of the sphere.
- Returns:
projected tree, with the same structure as tree.
Projection onto the L2 ball
- optax.projections.projection_l2_ball(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto the l2 ball.
This function solves the following constrained optimization problem, where \(x\) is the input tree:
\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_2 \le \text{scale}\]
- Parameters:
tree – tree to project.
scale – radius of the ball.
- Returns:
projected tree, with the same structure as tree.
Added in version 0.2.4.
Projection onto the L2 sphere
- optax.projections.projection_l2_sphere(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto the l2 sphere.
This function solves the following constrained optimization problem, where \(x\) is the input tree:
\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_2 = \text{scale}\]
- Parameters:
tree – tree to project.
scale – radius of the sphere.
- Returns:
projected tree, with the same structure as tree.
Added in version 0.2.4.
Projection onto the L-infinity ball
- optax.projections.projection_linf_ball(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto the l-infinity ball.
This function solves the following constrained optimization problem, where \(x\) is the input tree:
\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_{\infty} \le \text{scale}\]
- Parameters:
tree – tree to project.
scale – radius of the ball.
- Returns:
projected tree, with the same structure as tree.
Projection onto the non-negative orthant
- optax.projections.projection_non_negative(tree: Any) → Any
Projection onto the non-negative orthant.
\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \ge 0\]
where \(x\) is the input tree.
- Parameters:
tree – tree to project.
- Returns:
projected tree, with the same structure as tree.
Projection onto a simplex
- optax.projections.projection_simplex(tree: Any, scale: jax.typing.ArrayLike = 1) → Any
Projection onto a simplex.
This function solves the following constrained optimization problem, where \(x\) is the input tree:
\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \ge 0, ~ p^\top 1 = \text{scale}\]
By default, the projection is onto the probability simplex (unit simplex).
- Parameters:
tree – tree to project.
scale – value the projected tree should sum to (default: 1).
- Returns:
projected tree, with the same structure as tree.
Example
Here is an example using a tree:
>>> import jax.numpy as jnp
>>> from optax import tree, projections
>>> data = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree.sum(data))
6.2
>>> new_data = projections.projection_simplex(data)
>>> print(tree.sum(new_data))
1.0000002
Added in version 0.2.3.
Projection onto a vector
- optax.projections.projection_vector(x: Any, a: Any) → Any
Projection onto a vector.
Projects a tree \(x\) onto the vector defined by a tree \(a\):
\[\operatorname{proj}_a x = \frac{\langle x, a \rangle}{\langle a, a \rangle} a\]
- Parameters:
x – tree to project.
a – tree onto which to project. Must have the same structure as x.
- Returns:
tree with the same structure as x.
Projection onto a hyperplane
- optax.projections.projection_hyperplane(x: Any, a: Any, b: jax.typing.ArrayLike) → Any
Projection onto a hyperplane.
Projects a tree \(x\) onto the hyperplane defined by a tree \(a\) and a scalar \(b\):
\[\operatorname{argmin}_y \|x - y\|_2^2 \quad \text{subject to} \quad \langle a, y \rangle = b\]
- Parameters:
x – tree to project.
a – tree defining the hyperplane onto which to project. Must have the same structure as x.
b – scalar defining the hyperplane onto which to project.
- Returns:
tree with the same structure as x.
Projection onto a halfspace
- optax.projections.projection_halfspace(x: Any, a: Any, b: jax.typing.ArrayLike) → Any
Projection onto a halfspace.
Projects a tree \(x\) onto the halfspace defined by a tree \(a\) and a scalar \(b\):
\[\operatorname{argmin}_y \|x - y\|_2^2 \quad \text{subject to} \quad \langle a, y \rangle \leq b\]
- Parameters:
x – tree to project.
a – tree defining the halfspace onto which to project. Must have the same structure as x.
b – scalar defining the halfspace onto which to project.
- Returns:
tree with the same structure as x.