Projections#

Projections can be used to perform constrained optimization. The Euclidean projection onto a set \(\mathcal{C}\) is:

\[\text{proj}_{\mathcal{C}}(u) := \underset{v}{\text{argmin}} ~ ||u - v||^2_2 \textrm{ subject to } v \in \mathcal{C}.\]

For instance, here is how to project parameters onto the non-negative orthant after a gradient step:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)

Available projections#

projection_box(pytree, lower, upper)

Projection onto box constraints.

projection_hypercube(pytree[, scale])

Projection onto the (unit) hypercube.

projection_non_negative(pytree)

Projection onto the non-negative orthant.

projection_simplex(pytree[, scale])

Projection onto a simplex.

Projection onto a box#

optax.projections.projection_box(pytree, lower, upper)[source]#

Projection onto box constraints.

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad \text{lower} \le p \le \text{upper}\]

where \(x\) is the input pytree.

Parameters:
  • pytree (Any) – pytree to project.

  • lower (Any) – lower bound, a scalar or pytree with the same structure as pytree.

  • upper (Any) – upper bound, a scalar or pytree with the same structure as pytree.

Return type:

Any

Returns:

projected pytree, with the same structure as pytree.

Projection onto a hypercube#

optax.projections.projection_hypercube(pytree, scale=1.0)[source]#

Projection onto the (unit) hypercube.

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad 0 \le p \le \text{scale}\]

where \(x\) is the input pytree.

By default, we project onto the unit hypercube (scale=1.0).

This is a convenience wrapper around projection_box.

Parameters:
  • pytree (Any) – pytree to project.

  • scale (Any) – scale of the hypercube, a scalar or a pytree (default: 1.0).

Return type:

Any

Returns:

projected pytree, with the same structure as pytree.

Projection onto the non-negative orthant#

optax.projections.projection_non_negative(pytree)[source]#

Projection onto the non-negative orthant.

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad p \ge 0\]

where \(x\) is the input pytree.

Parameters:
  • pytree (Any) – pytree to project.

Return type:

Any

Returns:

projected pytree, with the same structure as pytree.

Projection onto a simplex#

optax.projections.projection_simplex(pytree, scale=1.0)[source]#

Projection onto a simplex.

This function solves the following constrained optimization problem, where \(x\) is the input pytree.

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad p \ge 0, p^\top 1 = \text{scale}\]

By default, the projection is onto the probability simplex (unit simplex).

Parameters:
  • pytree (Any) – pytree to project.

  • scale (Union[Array, ndarray, bool_, number, float, int]) – value the projected pytree should sum to (default: 1.0).

Return type:

Any

Returns:

projected pytree, with the same structure as pytree.

Added in version 0.2.3.

Example

Here is an example using a pytree:

>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> pytree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> tree_utils.tree_sum(pytree)
6.2
>>> new_pytree = projections.projection_simplex(pytree)
>>> tree_utils.tree_sum(new_pytree)
1.0000002