optax.nnls

optax.nnls(A: Array, b: Array, iters: int, unroll: int | bool = 1, L: jax.typing.ArrayLike | None = None) -> Array

Solves the non-negative least squares problem.

Minimizes \(\|A x - b\|_2\) subject to \(x \geq 0\).

Uses the fast projected gradient (FPG) algorithm of Polyak 2015.
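
To make the idea concrete, here is a minimal NumPy sketch of plain (unaccelerated) projected gradient for the same problem. It is not optax's implementation and omits the acceleration used by FPG; the function name nnls_projected_gradient and the default choice of L as the squared spectral norm of A are assumptions made for illustration.

```python
import numpy as np


def nnls_projected_gradient(A, b, iters, L=None):
    """Plain projected gradient for min ||A x - b||_2 subject to x >= 0.

    Hypothetical helper for illustration; not optax's implementation.
    """
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    if L is None:
        # Assumption: use the squared spectral norm of A, which equals
        # the largest eigenvalue of A.T @ A.
        L = np.linalg.norm(A, 2) ** 2
    # Start from the origin; handles b of shape (M,) or (M, K).
    x = np.zeros((A.shape[1],) + b.shape[1:])
    for _ in range(iters):
        grad = A.T @ (A @ x - b)  # gradient of 0.5 * ||A x - b||^2
        x = np.maximum(x - grad / L, 0.0)  # step, then project onto x >= 0
    return x


A = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([5.0, 6.0])
x = nnls_projected_gradient(A, b, iters=1000)
print(np.round(x, 2))
```

Each iteration takes a gradient step of size 1/L and clips negative entries to zero, which is the Euclidean projection onto the non-negative orthant. On this small input the sketch should land close to the values 0.00 and 1.70 reported in the Examples section.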

Parameters:
  • A – Input matrix of shape (M, N).

  • b – Input vector of shape (M,) or matrix of shape (M, K).

  • iters – Number of iterations to run the algorithm for.

  • unroll – Unroll parameter passed to lax.scan.

  • L – An upper bound on the spectral radius of A.mT @ A (optional).
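
As an illustration of what a valid value for L could look like, the NumPy sketch below compares the exact spectral radius of A.T @ A with two cheaper upper bounds that follow from standard norm inequalities. The variable names are hypothetical, and nothing here reflects how optax chooses L when it is left as None.

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])

# Exact spectral radius of A.T @ A: the largest eigenvalue of a
# symmetric positive semidefinite matrix.
L_exact = np.linalg.eigvalsh(A.T @ A).max()

# Squared Frobenius norm: ||A||_F^2 >= ||A||_2^2.
L_fro = np.linalg.norm(A, "fro") ** 2

# Product of the induced 1-norm and infinity-norm:
# ||A||_1 * ||A||_inf >= ||A||_2^2.
L_1inf = np.linalg.norm(A, 1) * np.linalg.norm(A, np.inf)

print(L_exact, L_fro, L_1inf)
```

Any of these values satisfies the stated requirement on L. For gradient methods with step size 1/L, a looser bound generally means smaller steps, so more iterations may be needed to reach the same accuracy.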

Returns:

A solution vector of shape (N,) or matrix of shape (N, K).

Examples

>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> b = jnp.array([5., 6.])
>>> x = optax.nnls(A, b, 10**3)
>>> print(f"{x[0]:.2f}")
0.00
>>> print(f"{x[1]:.2f}")
1.70
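
The printed values can be checked by hand. Minimizing \(\|A x - b\|_2\) is equivalent to minimizing \(\tfrac{1}{2}\|A x - b\|_2^2\), whose gradient is \(g = A^\top (A x - b)\). Because the problem is convex, a point is optimal exactly when \(x \geq 0\), \(g \geq 0\), and \(x_i g_i = 0\) for every \(i\). For \(x = (0, 1.7)\):

\[
A x - b = \begin{pmatrix} 3.4 \\ 6.8 \end{pmatrix} - \begin{pmatrix} 5 \\ 6 \end{pmatrix} = \begin{pmatrix} -1.6 \\ 0.8 \end{pmatrix},
\qquad
g = \begin{pmatrix} 1 & 3 \\ 2 & 4 \end{pmatrix} \begin{pmatrix} -1.6 \\ 0.8 \end{pmatrix} = \begin{pmatrix} 0.8 \\ 0 \end{pmatrix}.
\]

Both \(x\) and \(g\) are non-negative, and \(x_1 g_1 = 0 \cdot 0.8 = 0\) and \(x_2 g_2 = 1.7 \cdot 0 = 0\), so all three conditions hold.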

References

Roman A. Polyak, Projected gradient method for non-negative least square, 2015