optax.nnls
- optax.nnls(A: Array, b: Array, iters: int, unroll: int | bool = 1, L: jax.typing.ArrayLike | None = None) → Array
Solves the non-negative least squares problem.
Minimizes \(\|A x - b\|_2\) subject to \(x \geq 0\).
Uses the fast projected gradient (FPG) algorithm of Polyak (2015).
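As a rough illustration of the projected-gradient idea, the sketch below repeatedly takes a gradient step on \(\tfrac12\|A x - b\|_2^2\) (which has the same minimizer as \(\|A x - b\|_2\)) with step size 1/L, then clips negative entries to zero. This is plain projected gradient, not Polyak's accelerated FPG scheme, and not optax's implementation. The helper name nnls_projected_gradient is hypothetical; it only mirrors the iters, unroll and L parameters and drives the loop with jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def nnls_projected_gradient(A, b, iters, unroll=1, L=None):
    """Hypothetical helper: plain projected gradient for min ||A x - b|| s.t. x >= 0.

    Illustration only; this is NOT optax's FPG implementation.
    """
    AtA = A.T @ A  # (N, N) Gram matrix (A.T equals A.mT for a 2-D A)
    Atb = A.T @ b  # (N,) or (N, K), depending on the shape of b
    if L is None:
        # Spectral radius of A^T A equals the squared spectral norm of A.
        L = jnp.linalg.norm(A, ord=2) ** 2

    def step(x, _):
        grad = AtA @ x - Atb                # gradient of 0.5 * ||A x - b||^2
        x = jnp.maximum(x - grad / L, 0.0)  # gradient step, then project onto x >= 0
        return x, None

    x0 = jnp.zeros_like(Atb)  # start from the feasible point x = 0
    x, _ = jax.lax.scan(step, x0, None, length=iters, unroll=unroll)
    return x


A = jnp.array([[1., 2.], [3., 4.]])
b = jnp.array([5., 6.])
x = nnls_projected_gradient(A, b, 1000)
print(x)
```

On this small problem the unconstrained least-squares solution is (-4, 4.5), so the constraint is active and the iterates should approach (0, 1.7). Passing a matrix b of shape (M, K) solves K independent problems at once, because every operation in the loop acts column-wise.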
- Parameters:
A – Input matrix of shape (M, N).
b – Input vector of shape (M,) or matrix of shape (M, K).
iters – Number of iterations to run.
unroll – Unroll parameter passed to lax.scan.
L – An upper bound on the spectral radius of A.mT @ A (optional).
- Returns:
A solution vector of shape (N,) or matrix of shape (N, K).
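The optional L argument can be precomputed, for example when many problems share the same A. The snippet below is a sketch of three ways to obtain a valid bound with jax.numpy. The first two compute the spectral radius of A.mT @ A exactly. The third, the squared Frobenius norm, is cheaper but looser. How optax.nnls chooses L when it is None is not stated here, so this snippet makes no claim about that default.

```python
import jax.numpy as jnp

A = jnp.array([[1., 2.], [3., 4.]])

# A^T A is symmetric positive semidefinite, so its spectral radius is its
# largest eigenvalue (A.T equals A.mT for a 2-D matrix).
L_eig = jnp.linalg.eigvalsh(A.T @ A).max()

# Same quantity: the squared spectral norm (largest singular value) of A.
L_svd = jnp.linalg.norm(A, ord=2) ** 2

# Cheaper, looser upper bound: the squared Frobenius norm, since ||A||_2 <= ||A||_F.
L_fro = jnp.sum(A ** 2)

print(float(L_eig), float(L_svd), float(L_fro))
```

Any of these could be supplied as optax.nnls(A, b, iters, L=L_svd). If L sets a 1/L step size (an assumption the text above does not confirm), a looser bound such as L_fro stays valid but would likely mean smaller steps and slower convergence.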
Examples
>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> b = jnp.array([5., 6.])
>>> x = optax.nnls(A, b, 10**3)
>>> print(f"{x[0]:.2f}")
0.00
>>> print(f"{x[1]:.2f}")
1.70
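One way to sanity-check a result like the one above is to test the optimality (KKT) conditions of non-negative least squares. A vector x is optimal exactly when x ≥ 0, the gradient g = Aᵀ(Ax − b) is ≥ 0, and x ⊙ g = 0 elementwise. The helper nnls_kkt_violation below is hypothetical and not part of optax. It returns the largest violation of these three conditions, which should be near zero for the doctest's answer x = (0, 1.7).

```python
import jax.numpy as jnp


def nnls_kkt_violation(A, b, x):
    """Hypothetical helper: largest violation of the NNLS optimality conditions."""
    g = A.T @ (A @ x - b)                      # gradient of 0.5 * ||A x - b||^2
    primal = jnp.max(jnp.maximum(-x, 0.0))     # how negative any x_i is
    dual = jnp.max(jnp.maximum(-g, 0.0))       # how negative any g_i is
    complementarity = jnp.max(jnp.abs(x * g))  # x_i and g_i both nonzero
    return jnp.maximum(jnp.maximum(primal, dual), complementarity)


A = jnp.array([[1., 2.], [3., 4.]])
b = jnp.array([5., 6.])
x = jnp.array([0.0, 1.7])  # the solution printed by the doctest above
print(float(nnls_kkt_violation(A, b, x)))
```

Here g = (0.8, 0): the first coordinate sits on the bound with a positive gradient, and the second is free with a zero gradient, so all three conditions hold.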
References
Roman A. Polyak, Projected gradient method for non-negative least square, 2015