Projections

Projections can be used to perform constrained optimization. The Euclidean projection onto a set \(\mathcal{C}\) is:

\[\text{proj}_{\mathcal{C}}(u) := \underset{v}{\text{argmin}} ~ \|u - v\|^2_2 \textrm{ subject to } v \in \mathcal{C}.\]

For instance, here is how to project parameters onto the non-negative orthant after an optimizer update:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)
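
Repeating the step-then-project pattern above gives projected gradient descent. The following is a minimal plain-Python sketch on a toy objective, written without optax or JAX so that it runs anywhere. The names `target`, `project_non_negative` and `step_size` are illustrative assumptions, not part of the library.

```python
# Toy problem (assumed for illustration): minimize 0.5 * ||w - target||^2
# subject to w >= 0. The unconstrained minimizer is `target` itself.
target = [-1.0, 2.0]


def project_non_negative(values):
    """Euclidean projection of a flat list onto the non-negative orthant."""
    return [max(v, 0.0) for v in values]


w = [0.0, 0.0]
step_size = 0.5
for _ in range(50):
    # Gradient of the toy loss with respect to w.
    grads = [wi - ti for wi, ti in zip(w, target)]
    # Plain gradient step.
    w = [wi - step_size * gi for wi, gi in zip(w, grads)]
    # Project back onto the constraint set.
    w = project_non_negative(w)

print(w)
```

The projection holds the first coordinate at 0, while the second converges to its unconstrained optimum of 2.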

Available projections

  • projection_box(tree, lower, upper) – Projection onto box constraints.

  • projection_hypercube(tree[, scale]) – Projection onto the (unit) hypercube.

  • projection_l1_ball(tree[, scale]) – Projection onto the l1 ball.

  • projection_l1_sphere(tree[, scale]) – Projection onto the l1 sphere.

  • projection_l2_ball(tree[, scale]) – Projection onto the l2 ball.

  • projection_l2_sphere(tree[, scale]) – Projection onto the l2 sphere.

  • projection_linf_ball(tree[, scale]) – Projection onto the l-infinity ball.

  • projection_non_negative(tree) – Projection onto the non-negative orthant.

  • projection_simplex(tree[, scale]) – Projection onto a simplex.

Projection onto a box

optax.projections.projection_box(tree: Any, lower: Any, upper: Any) -> Any

Projection onto box constraints.

\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad \text{lower} \le p \le \text{upper}\]

where \(x\) is the input tree.

Parameters:
  • tree – tree to project.

  • lower – lower bound, a scalar or tree with the same structure as tree.

  • upper – upper bound, a scalar or tree with the same structure as tree.

Returns:

projected tree, with the same structure as tree.
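
Because both the objective and the constraint separate across coordinates, the problem above reduces to clipping each entry. A minimal plain-Python sketch on a flat list with scalar bounds follows. The helper name `project_box` is hypothetical, and the sketch does not handle trees or per-entry bounds as the library function does.

```python
def project_box(values, lower, upper):
    """Clip each entry of a flat list into [lower, upper]."""
    return [min(max(v, lower), upper) for v in values]


x = [-1.5, 0.3, 2.0]
p = project_box(x, lower=-1.0, upper=1.0)
print(p)  # → [-1.0, 0.3, 1.0]
```

Entries already inside the bounds are left unchanged. The others move to the nearest bound.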

Projection onto a hypercube

optax.projections.projection_hypercube(tree: Any, scale: Any = 1) -> Any

Projection onto the (unit) hypercube.

\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad 0 \le p \le \text{scale}\]

where \(x\) is the input tree.

By default, we project to the unit hypercube (scale=1).

This is a convenience wrapper around projection_box.

Parameters:
  • tree – tree to project.

  • scale – scale of the hypercube, a scalar or a tree (default: 1).

Returns:

projected tree, with the same structure as tree.
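
The relationship to the box projection can be sketched in plain Python as below. Both helper names are hypothetical stand-ins that operate on a flat list with a scalar `scale`. They only illustrate the formula and are not the library code.

```python
def project_box(values, lower, upper):
    """Clip each entry of a flat list into [lower, upper]."""
    return [min(max(v, lower), upper) for v in values]


def project_hypercube(values, scale=1.0):
    """Hypercube projection expressed as a box projection onto [0, scale]."""
    return project_box(values, 0.0, scale)


x = [-0.2, 0.4, 3.0]
unit = project_hypercube(x)                # default: unit hypercube
scaled = project_hypercube(x, scale=2.0)   # hypercube with side length 2
print(unit, scaled)  # → [0.0, 0.4, 1.0] [0.0, 0.4, 2.0]
```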

Projection onto the L1 ball

optax.projections.projection_l1_ball(tree: Any, scale: chex.Numeric = 1) -> Any

Projection onto the l1 ball.

This function solves the following constrained optimization problem, where \(x\) is the input tree:

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_1 \le \text{scale}\]

Parameters:
  • tree – tree to project.

  • scale – radius of the ball.

Returns:

projected tree, with the same structure as tree.

Example

>>> import jax.numpy as jnp
>>> from optax import tree, projections
>>> data = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree.norm(data, ord=1))
6.2
>>> new_data = projections.projection_l1_ball(data)
>>> print(tree.norm(new_data, ord=1))
1.0000002

Added in version 0.2.4.
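
To build intuition for what this projection does, note that outside the ball the solution is a soft-thresholding of the input: every magnitude is reduced by a common threshold `tau` and clamped at zero, with `tau` chosen so that the result has l1 norm equal to `scale`. The plain-Python sketch below finds `tau` by bisection. This technique was chosen for brevity and is not necessarily how optax computes it. The helper name, the `num_steps` parameter and the assumption `scale > 0` are illustrative.

```python
import math


def project_l1_ball(values, scale=1.0, num_steps=60):
    """Project a flat list onto the l1 ball of radius `scale` (assumed > 0)."""
    magnitudes = [abs(v) for v in values]
    if sum(magnitudes) <= scale:
        return list(values)  # already inside the ball
    # Bisection on tau: the l1 norm after thresholding decreases as tau grows.
    low, high = 0.0, max(magnitudes)
    for _ in range(num_steps):
        tau = 0.5 * (low + high)
        if sum(max(m - tau, 0.0) for m in magnitudes) > scale:
            low = tau
        else:
            high = tau
    tau = 0.5 * (low + high)
    # Shrink magnitudes by tau and restore the original signs.
    return [math.copysign(max(abs(v) - tau, 0.0), v) for v in values]


x = [2.5, -3.2, 0.5]
p = project_l1_ball(x)
print(p, sum(abs(v) for v in p))
```

For this input the threshold is about 2.35, so the smallest entry is set to zero and the result is approximately `[0.15, -0.85, 0.0]`.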

Projection onto the L1 sphere

optax.projections.projection_l1_sphere(tree: Any, scale: chex.Numeric = 1) -> Any

Projection onto the l1 sphere.

This function solves the following constrained optimization problem, where \(x\) is the input tree:

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_1 = \text{scale}\]

Parameters:
  • tree – tree to project.

  • scale – radius of the sphere.

Returns:

projected tree, with the same structure as tree.
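
One way to compute such a projection is to project the magnitudes onto a simplex of the same scale and then restore the signs. The plain-Python sketch below follows that idea on a flat list. Both helper names are hypothetical, `scale > 0` is assumed, and zero entries are treated as positive, which is an arbitrary choice. The l1 sphere is not convex, so the closest point is not always unique (for example, for an all-zero input).

```python
def project_simplex(values, scale=1.0):
    """Sort-based projection of a flat list onto {p >= 0, sum(p) = scale}."""
    threshold, cumulative = 0.0, 0.0
    for k, u in enumerate(sorted(values, reverse=True), start=1):
        cumulative += u
        candidate = (cumulative - scale) / k
        if u - candidate > 0:
            threshold = candidate  # keep the last threshold that is still valid
    return [max(v - threshold, 0.0) for v in values]


def project_l1_sphere(values, scale=1.0):
    """Project a flat list onto the l1 sphere of radius `scale`."""
    magnitudes = project_simplex([abs(v) for v in values], scale)
    # Restore signs; a zero entry is treated as positive (assumption).
    return [m if v >= 0 else -m for m, v in zip(magnitudes, values)]


x = [0.2, -0.1]  # l1 norm 0.3, strictly inside the unit ball
p = project_l1_sphere(x)
print(p, sum(abs(v) for v in p))
```

Unlike the l1 ball projection, a point strictly inside the ball is moved outward until its l1 norm equals `scale`. Here the result is approximately `[0.55, -0.45]`.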

Projection onto the L2 ball

optax.projections.projection_l2_ball(tree: Any, scale: chex.Numeric = 1) -> Any

Projection onto the l2 ball.

This function solves the following constrained optimization problem, where \(x\) is the input tree:

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_2 \le \text{scale}\]

Parameters:
  • tree – tree to project.

  • scale – radius of the ball.

Returns:

projected tree, with the same structure as tree.

Added in version 0.2.4.
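
This problem has a simple closed form: a point inside the ball is left unchanged, and a point outside is rescaled so that its l2 norm equals `scale`. A minimal plain-Python sketch on a flat list follows. The helper name is hypothetical, and tree handling is omitted.

```python
import math


def project_l2_ball(values, scale=1.0):
    """Project a flat list onto the l2 ball of radius `scale`."""
    norm = math.sqrt(sum(v * v for v in values))
    if norm <= scale:
        return list(values)  # already feasible
    return [v * (scale / norm) for v in values]


outside = project_l2_ball([3.0, 4.0])  # norm 5, rescaled to about [0.6, 0.8]
inside = project_l2_ball([0.3, 0.4])   # norm 0.5, returned unchanged
print(outside, inside)
```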

Projection onto the L2 sphere

optax.projections.projection_l2_sphere(tree: Any, scale: chex.Numeric = 1) -> Any

Projection onto the l2 sphere.

This function solves the following constrained optimization problem, where \(x\) is the input tree:

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_2 = \text{scale}\]

Parameters:
  • tree – tree to project.

  • scale – radius of the sphere.

Returns:

projected tree, with the same structure as tree.

Added in version 0.2.4.
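
For any non-zero input, the closest point on the sphere is the input rescaled to have l2 norm `scale`. The plain-Python sketch below illustrates this on a flat list. The helper name is hypothetical, and raising an error for the all-zero input (where every point on the sphere is equally close) is an assumption of this sketch, not documented library behavior.

```python
import math


def project_l2_sphere(values, scale=1.0):
    """Project a non-zero flat list onto the l2 sphere of radius `scale`."""
    norm = math.sqrt(sum(v * v for v in values))
    if norm == 0.0:
        raise ValueError("the projection of the zero vector is not unique")
    return [v * (scale / norm) for v in values]


x = [0.3, 0.4]  # norm 0.5, strictly inside the unit ball
p = project_l2_sphere(x)
print(p)  # approximately [0.6, 0.8]
```

In contrast to the l2 ball projection, points strictly inside the ball are also moved, outward onto the sphere.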

Projection onto the L-infinity ball

optax.projections.projection_linf_ball(tree: Any, scale: chex.Numeric = 1) -> Any

Projection onto the l-infinity ball.

This function solves the following constrained optimization problem, where \(x\) is the input tree:

\[\underset{y}{\text{argmin}} ~ \|x - y\|_2^2 \quad \textrm{subject to} \quad \|y\|_{\infty} \le \text{scale}\]

Parameters:
  • tree – tree to project.

  • scale – radius of the ball.

Returns:

projected tree, with the same structure as tree.
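
The constraint \(\|y\|_{\infty} \le \text{scale}\) says that every entry lies in \([-\text{scale}, \text{scale}]\), so the problem is mathematically a box projection with symmetric bounds. A minimal plain-Python sketch on a flat list follows, using a hypothetical helper name.

```python
def project_linf_ball(values, scale=1.0):
    """Clip each entry of a flat list into [-scale, scale]."""
    return [min(max(v, -scale), scale) for v in values]


x = [-3.0, 0.5, 2.0]
p = project_linf_ball(x)
print(p)  # → [-1.0, 0.5, 1.0]
```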

Projection onto the non-negative orthant

optax.projections.projection_non_negative(tree: Any) -> Any

Projection onto the non-negative orthant.

\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \ge 0\]

where \(x\) is the input tree.

Parameters:
  • tree – tree to project.

Returns:

projected tree, with the same structure as tree.
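
The solution of this problem is to replace every negative entry by zero. As a sanity check of the argmin definition, the plain-Python sketch below compares that candidate against randomly sampled feasible points and counts how many are strictly closer to the input. The helper names, the sampling range and the seed are arbitrary illustrative choices.

```python
import random


def project_non_negative(values):
    """Replace negative entries of a flat list by zero."""
    return [max(v, 0.0) for v in values]


def squared_distance(a, b):
    """Squared Euclidean distance between two flat lists."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


x = [-1.8, 2.2, -0.4]
p = project_non_negative(x)
best = squared_distance(x, p)

# Sample feasible (non-negative) points and look for one closer to x than p.
rng = random.Random(0)
candidates = [[rng.uniform(0.0, 4.0) for _ in x] for _ in range(1000)]
closer = [c for c in candidates if squared_distance(x, c) < best]
print(p, len(closer))  # → [0.0, 2.2, 0.0] 0
```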

Projection onto a simplex

optax.projections.projection_simplex(tree: Any, scale: chex.Numeric = 1) -> Any

Projection onto a simplex.

This function solves the following constrained optimization problem, where x is the input tree.

\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \ge 0, p^\top 1 = \text{scale}\]

By default, the projection is onto the probability simplex (unit simplex).

Parameters:
  • tree – tree to project.

  • scale – value the projected tree should sum to (default: 1).

Returns:

projected tree, with the same structure as tree.

Example

Here is an example using a tree:

>>> import jax.numpy as jnp
>>> from optax import tree, projections
>>> data = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree.sum(data))
6.2
>>> new_data = projections.projection_simplex(data)
>>> print(tree.sum(new_data))
1.0000002

Added in version 0.2.3.
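
For readers curious about how such a projection can be computed, the plain-Python sketch below uses the classic sort-based method: shift all entries by a common threshold and clamp at zero, with the threshold determined from the sorted values. It operates on a flat list, assumes `scale > 0`, and uses a hypothetical helper name. It is an illustration and not necessarily the algorithm used by optax.

```python
def project_simplex(values, scale=1.0):
    """Project a flat list onto {p >= 0, sum(p) = scale} by sorting."""
    threshold, cumulative = 0.0, 0.0
    for k, u in enumerate(sorted(values, reverse=True), start=1):
        cumulative += u
        # Threshold that would make the k largest entries sum to `scale`.
        candidate = (cumulative - scale) / k
        if u - candidate > 0:
            threshold = candidate  # the k-th largest entry stays positive
    return [max(v - threshold, 0.0) for v in values]


# Same numbers as the tree example above, flattened into one list.
x = [2.5, 3.2, 0.5]
p = project_simplex(x)
print(p, sum(p))
```

For this input the threshold is about 2.35, giving approximately `[0.15, 0.85, 0.0]`, which is non-negative and sums to 1 up to rounding.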