Projections
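Projections can be used to perform constrained optimization. The Euclidean projection onto a set \(\mathcal{C}\) is:
\[\text{proj}_{\mathcal{C}}(x) := \underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad p \in \mathcal{C}\]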
For instance, here is how we can project parameters onto the non-negative orthant after an optimizer update:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)
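The example above performs a single iteration of projected gradient descent: take an optimizer step, then project back onto the feasible set. The following is a minimal sketch of the full loop, written in plain JAX so it runs without Optax. Plain gradient descent with a fixed step size of 0.1 stands in for `optax.adam`, and `jnp.maximum(p, 0.0)` plays the role of `projection_non_negative`; the step size and the 200 iterations are illustrative choices, not values from the text.

```python
import jax
import jax.numpy as jnp

# Same toy least-squares problem as in the example above.
xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
ys = jnp.array([0.5, 0.8])


def loss(params, x, y):
  return jnp.mean((params['w'].dot(x) - y) ** 2)


def project_non_negative(params):
  # Clamp every leaf at zero: the Euclidean projection onto {p >= 0}.
  return jax.tree_util.tree_map(lambda p: jnp.maximum(p, 0.0), params)


@jax.jit
def step(params):
  grads = jax.grad(loss)(params, xs, ys)
  # Plain gradient step (a stand-in for an Optax optimizer update).
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  # Project back onto the feasible set after every step.
  return project_non_negative(params)


params = {'w': jnp.zeros(2)}
for _ in range(200):
  params = step(params)
```

For this data the unconstrained least-squares solution has a negative second weight, so the constraint is active: the iterates settle on the boundary with `w[1] = 0` and `w[0]` close to 0.86 / 8.08 ≈ 0.1064.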
Available projections
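| Function | Description |
| --- | --- |
| projection_box(pytree, lower, upper) | Projection onto box constraints. |
| projection_hypercube(pytree, scale=1.0) | Projection onto the (unit) hypercube. |
| projection_non_negative(pytree) | Projection onto the non-negative orthant. |
| projection_simplex(pytree, scale=1.0) | Projection onto a simplex. |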
Projection onto a box
- optax.projections.projection_box(pytree, lower, upper)
Projection onto box constraints.
\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad \text{lower} \le p \le \text{upper}\]
where \(x\) is the input pytree.
- Parameters:
pytree (Any) – pytree to project.
lower (Any) – lower bound, a scalar or pytree with the same structure as pytree.
upper (Any) – upper bound, a scalar or pytree with the same structure as pytree.
- Return type:
Any
- Returns:
projected pytree, with the same structure as pytree.
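Because the squared distance is a sum of independent per-coordinate terms, the box projection amounts to clipping each entry into its interval [lower, upper]. The following is a minimal sketch in plain JAX. The helper `box_project` is hypothetical, not Optax API, and it assumes scalar bounds are Python numbers; any other bound must be a pytree with the same structure as the input.

```python
import jax
import jax.numpy as jnp


def box_project(pytree, lower, upper):
  """Hypothetical helper: Euclidean projection onto {p : lower <= p <= upper}.

  The objective separates across coordinates, so the minimizer clips every
  coordinate into its own interval.
  """
  def as_tree(bound):
    # Broadcast a scalar bound (Python number) to every leaf of the input.
    if isinstance(bound, (int, float)):
      return jax.tree_util.tree_map(lambda _: bound, pytree)
    return bound

  return jax.tree_util.tree_map(
      jnp.clip, pytree, as_tree(lower), as_tree(upper))


params = {'w': jnp.array([-1.0, 0.3, 2.0]), 'b': jnp.array(5.0)}

# The same scalar bounds for every leaf.
clipped = box_project(params, 0.0, 1.0)

# Per-leaf bounds given as pytrees with the structure of `params`.
lower = {'w': jnp.zeros(3), 'b': jnp.array(-1.0)}
upper = {'w': jnp.full(3, 0.5), 'b': jnp.array(2.0)}
clipped_per_leaf = box_project(params, lower, upper)
```

Here `clipped` is `{'w': [0.0, 0.3, 1.0], 'b': 1.0}`, and `clipped_per_leaf` is `{'w': [0.0, 0.3, 0.5], 'b': 2.0}`.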
Projection onto a hypercube
- optax.projections.projection_hypercube(pytree, scale=1.0)
Projection onto the (unit) hypercube.
\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad 0 \le p \le \text{scale}\]
where \(x\) is the input pytree.
By default, we project to the unit hypercube (scale=1.0).
This is a convenience wrapper around projection_box.
- Parameters:
pytree (Any) – pytree to project.
scale (Any) – scale of the hypercube, a scalar or a pytree (default: 1.0).
- Return type:
Any
- Returns:
projected pytree, with the same structure as pytree.
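Since the hypercube is the box with lower bound 0 and upper bound scale, the projection again reduces to clipping. The following is a minimal plain-JAX sketch. The helper `hypercube_project` is hypothetical and assumes a scalar `scale` is a Python number; otherwise `scale` must be a pytree matching the input.

```python
import jax
import jax.numpy as jnp


def hypercube_project(pytree, scale=1.0):
  """Hypothetical helper: box projection with lower bound 0 and upper bound
  `scale` (a Python number or a pytree matching `pytree`)."""
  if isinstance(scale, (int, float)):
    # One scale shared by every leaf.
    return jax.tree_util.tree_map(lambda x: jnp.clip(x, 0.0, scale), pytree)
  # Per-leaf scales: map over both trees in parallel.
  return jax.tree_util.tree_map(
      lambda x, s: jnp.clip(x, 0.0, s), pytree, scale)


params = {'w': jnp.array([-0.5, 0.4, 3.0])}
unit = hypercube_project(params)                # default scale=1.0
wide = hypercube_project(params, scale=2.0)
per_entry = hypercube_project(
    params, scale={'w': jnp.array([1.0, 0.2, 5.0])})
```

With the default scale every entry lands in [0, 1], giving `[0.0, 0.4, 1.0]`. Passing a pytree for `scale` gives each entry its own upper bound, giving `[0.0, 0.2, 3.0]` in the last case.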
Projection onto the non-negative orthant
- optax.projections.projection_non_negative(pytree)
Projection onto the non-negative orthant.
\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad p \ge 0\]
where \(x\) is the input pytree.
- Parameters:
pytree (Any) – pytree to project.
- Return type:
Any
- Returns:
projected pytree, with the same structure as pytree.
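The problem decouples per coordinate: \(\min_p (x - p)^2\) subject to \(p \ge 0\) is attained at \(p = \max(x, 0)\). The following is a minimal plain-JAX sketch of that closed form. It also spot-checks the argmin property against random feasible points. The helpers `non_negative_project` and `sq_dist` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def non_negative_project(pytree):
  """Hypothetical helper: replace negative entries of every leaf by zero."""
  return jax.tree_util.tree_map(lambda x: jnp.maximum(x, 0.0), pytree)


def sq_dist(a, b):
  # Squared Euclidean distance between two pytrees of the same structure.
  return sum(jnp.sum((u - v) ** 2) for u, v in zip(
      jax.tree_util.tree_leaves(a), jax.tree_util.tree_leaves(b)))


x = {'w': jnp.array([-1.5, 0.0, 2.0]), 'b': jnp.array([-0.2])}
p = non_negative_project(x)

# Spot-check the argmin property: random non-negative points q are never
# closer to x than p is.
for key in jax.random.split(jax.random.PRNGKey(0), 100):
  q = jax.tree_util.tree_map(
      lambda leaf: jnp.abs(jax.random.normal(key, leaf.shape)), x)
  assert sq_dist(x, q) >= sq_dist(x, p)
```

Only the negative entries move, so the distance from `x` to its projection is \(1.5^2 + 0.2^2 = 2.29\).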
Projection onto a simplex
- optax.projections.projection_simplex(pytree, scale=1.0)
Projection onto a simplex.
This function solves the following constrained optimization problem, where \(x\) is the input pytree:
\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad p \ge 0, p^\top 1 = \text{scale}\]
By default, the projection is onto the probability simplex (unit simplex).
- Parameters:
pytree – pytree to project.
scale – value that the entries of the projected pytree sum to (default: 1.0).
- Return type:
Any
- Returns:
projected pytree, a pytree with the same structure as pytree.
Added in version 0.2.3.
Example
Here is an example using a pytree:
>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> pytree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> tree_utils.tree_sum(pytree)
6.2
>>> new_pytree = projections.projection_simplex(pytree)
>>> tree_utils.tree_sum(new_pytree)
1.0000002
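This section does not say how projection_simplex computes its result. One standard approach, sketched here as an assumption rather than as Optax's implementation, is the sort-and-threshold algorithm. It sorts the entries in decreasing order and finds the threshold \(\tau\) with \(\sum_i \max(x_i - \tau, 0) = \text{scale}\). It then returns \(\max(x - \tau, 0)\). Flattening with `jax.flatten_util.ravel_pytree` makes the sum constraint range over all leaves jointly, matching \(p^\top 1 = \text{scale}\) above. The helper names are hypothetical, and `scale > 0` is assumed.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def simplex_project_vector(x, scale=1.0):
  """Sort-based projection of a 1-D array onto {p >= 0, sum(p) = scale}.

  Assumes scale > 0, which guarantees at least one entry stays positive.
  """
  u = jnp.sort(x)[::-1]                  # entries in decreasing order
  css = jnp.cumsum(u)                    # sums of the k largest entries
  k = jnp.arange(1, x.size + 1)
  # Largest k whose k-th largest entry stays positive after the shift.
  support = u - (css - scale) / k > 0
  rho = jnp.max(jnp.where(support, k, 0))
  tau = (css[rho - 1] - scale) / rho     # threshold subtracted from x
  return jnp.maximum(x - tau, 0.0)


def simplex_project(pytree, scale=1.0):
  """Hypothetical pytree version: project all leaves jointly."""
  flat, unravel = ravel_pytree(pytree)
  return unravel(simplex_project_vector(flat, scale))


# The pytree from the example above, with "b" stored as a JAX array.
pytree = {"w": jnp.array([2.5, 3.2]), "b": jnp.array(0.5)}
new_pytree = simplex_project(pytree)
```

For this input the threshold is \(\tau = (3.2 + 2.5 - 1)/2 = 2.35\). The result is `b = 0.0` and `w = [0.15, 0.85]`, which sum to 1 up to float32 rounding.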