Projections
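Projections can be used to perform constrained optimization. The Euclidean projection onto a set \(\mathcal{C}\) is:

\[\text{proj}_{\mathcal{C}}(u) := \underset{v}{\text{argmin}} ~ \|u - v\|^2_2 \textrm{ subject to } v \in \mathcal{C}.\]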
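A quick numerical check can make the projection \(\text{proj}_{\mathcal{C}}(u)\) (the point of \(\mathcal{C}\) closest to \(u\) in Euclidean distance) more concrete. The sketch below assumes \(\mathcal{C}\) is the non-negative orthant, whose projection is the elementwise \(\max(u, 0)\). For a closed convex set, \(p = \text{proj}_{\mathcal{C}}(u)\) satisfies \(\langle u - p, v - p \rangle \le 0\) for every \(v \in \mathcal{C}\), and no feasible point is closer to \(u\) than \(p\). The helper `proj_orthant` is hypothetical and is not part of Optax.

```python
import jax
import jax.numpy as jnp

def proj_orthant(u):
    # Hypothetical helper: closed-form Euclidean projection onto {v : v >= 0}.
    return jnp.maximum(u, 0.0)

u = jnp.array([-1.5, 0.3, 2.0, -0.2])
p = proj_orthant(u)  # elementwise max(u, 0)

# Random feasible points v >= 0 to compare against.
key = jax.random.PRNGKey(0)
vs = jnp.abs(jax.random.normal(key, (1000, u.shape[0])))

# Variational inequality: <u - p, v - p> <= 0 for every feasible v.
inner = (vs - p) @ (u - p)

# Optimality: no feasible v is closer to u than p.
dist_p = jnp.sum((u - p) ** 2)
dist_v = jnp.sum((vs - u) ** 2, axis=1)

print(bool(jnp.all(inner <= 1e-6)))              # True
print(bool(jnp.all(dist_v >= dist_p - 1e-6)))    # True
```

Sampling only tests finitely many points, so this illustrates the definition rather than proving it.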
For instance, here is an example of how we can project parameters onto the non-negative orthant after a gradient step:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> # Toy least-squares data.
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> # One unconstrained Adam step...
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> # ...followed by a projection back onto the feasible set {w : w >= 0}.
>>> params = optax.projections.projection_non_negative(params)
Available projections
- Projection onto box constraints.
- Projection onto the (unit) hypercube.
- Projection onto the l1 ball.
- Projection onto the l1 sphere.
- Projection onto the l2 ball.
- Projection onto the l2 sphere.
- Projection onto the l-infinity ball.
- Projection onto the non-negative orthant.
- Projection onto a simplex.
- Projection onto a vector.
- Projection onto a hyperplane.
- Projection onto a halfspace.
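Several of the projections listed above have simple closed forms. Below is a minimal sketch in plain JAX that acts on a single flat vector, whereas the library functions act on pytrees. The helper names (`proj_box`, `proj_l2_ball`, `proj_hyperplane`, `proj_halfspace`, `proj_simplex`) are hypothetical. The formulas are the standard textbook ones, so treat this as an illustration rather than as Optax's own implementation. The simplex case uses the common sort-and-threshold method.

```python
import jax.numpy as jnp

def proj_box(u, lower, upper):
    # Box constraints: clip each coordinate to [lower, upper].
    return jnp.clip(u, lower, upper)

def proj_l2_ball(u, radius=1.0):
    # l2 ball: rescale u onto the sphere only if it lies outside the ball.
    norm = jnp.linalg.norm(u)
    return jnp.where(norm > radius, u * (radius / norm), u)

def proj_hyperplane(u, a, b):
    # Hyperplane {v : <a, v> = b}: step along a to remove the residual.
    return u - (jnp.dot(a, u) - b) / jnp.dot(a, a) * a

def proj_halfspace(u, a, b):
    # Halfspace {v : <a, v> <= b}: move only if the constraint is violated.
    excess = jnp.maximum(jnp.dot(a, u) - b, 0.0)
    return u - excess / jnp.dot(a, a) * a

def proj_simplex(u, scale=1.0):
    # Simplex {v : v >= 0, sum(v) = scale}, sort-and-threshold method.
    mu = jnp.sort(u)[::-1]                  # entries in decreasing order
    css = jnp.cumsum(mu) - scale            # shifted partial sums
    idx = jnp.arange(1, u.shape[0] + 1)
    rho = jnp.sum(mu - css / idx > 0)       # number of nonzero entries
    theta = css[rho - 1] / rho              # threshold to subtract
    return jnp.maximum(u - theta, 0.0)

u = jnp.array([0.8, -0.3, 1.5])
a = jnp.array([1.0, 1.0, 1.0])
print(proj_box(u, 0.0, 1.0))
print(proj_l2_ball(u))
print(proj_hyperplane(u, a, 1.0))
print(proj_simplex(u))
```

For `u = [0.8, -0.3, 1.5]`, the simplex projection subtracts the threshold `0.65` and clips at zero. This gives `[0.15, 0, 0.85]`, whose entries sum to 1.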