optax.matrix_inverse_pth_root
- optax.matrix_inverse_pth_root(matrix: jax.typing.ArrayLike, p: jax.typing.ArrayLike, num_iters: jax.typing.ArrayLike = 100, ridge_epsilon: jax.typing.ArrayLike = 1e-06, error_tolerance: jax.typing.ArrayLike = 1e-06, precision: Precision = Precision.HIGHEST)
Computes matrix^(-1/p), where p is a positive integer.
This function uses the coupled Newton iteration to compute the inverse p-th root of a matrix.
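The coupled iteration can be written as follows. This is our reading of the standard coupled form of the inverse Newton iteration from the Higham reference cited below, not a transcription of optax's source, and optax's exact initialization and stopping rule may differ. Here $A$ is the input matrix, $c > 0$ is a scaling constant chosen by the implementation, and $T_k$ is an auxiliary name introduced only for readability:

$$
\begin{aligned}
X_0 &= \tfrac{1}{c}\, I, \qquad M_0 = \tfrac{1}{c^{p}}\, A, \\
T_k &= \tfrac{1}{p}\bigl((p+1)\, I - M_k\bigr), \\
X_{k+1} &= X_k T_k, \qquad M_{k+1} = T_k^{\,p} M_k .
\end{aligned}
$$

Every iterate is a polynomial in $A$, so the invariant $M_k = X_k^{\,p} A$ holds at each step. When the eigenvalues of $M_0$ lie in $(0, p+1)$, $M_k \to I$, and therefore $X_k \to A^{-1/p}$.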
- Parameters:
matrix – the symmetric PSD matrix whose power is to be computed.
p – exponent; must be a positive integer.
num_iters – Maximum number of iterations.
ridge_epsilon – Ridge term added to the matrix to make it positive definite.
error_tolerance – Error threshold; the iterations may stop early once the error falls below it.
precision – XLA precision flag. The available options are: a) lax.Precision.DEFAULT (fastest step time, least precise); b) lax.Precision.HIGH (more precise, slower); c) lax.Precision.HIGHEST (most precise, slowest).
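The NumPy sketch below is a rough illustration of how `num_iters`, `ridge_epsilon`, and `error_tolerance` might enter a coupled Newton iteration. `coupled_newton_inverse_pth_root` is a hypothetical helper, not optax's implementation. It differs from optax in several ways: it works in float64 NumPy rather than JAX, uses a plain Python loop, ignores `precision`, and measures the error as the largest entry of |M - I|. It also assumes a Frobenius-norm initial scaling, which places the eigenvalues of the starting iterate in (0, (p + 1)/2].

```python
import numpy as np


def coupled_newton_inverse_pth_root(matrix, p, num_iters=100,
                                    ridge_epsilon=1e-6, error_tolerance=1e-6):
    """Approximate (matrix + ridge_epsilon * I)^(-1/p) for a symmetric PSD matrix.

    Hypothetical illustration only; not the optax implementation.
    """
    matrix = np.asarray(matrix, dtype=np.float64)
    identity = np.eye(matrix.shape[0])
    # Ridge term turns a PSD matrix into a strictly positive definite one.
    damped = matrix + ridge_epsilon * identity
    # Assumed scaling: eigenvalues of M_0 = z * A end up in (0, (p + 1) / 2].
    z = (1 + p) / (2 * np.linalg.norm(damped))  # Frobenius norm
    mat_m = z * damped                 # M_0 = z * A
    mat_h = z ** (1.0 / p) * identity  # X_0 = z^(1/p) * I, so M_0 = X_0^p A
    for _ in range(num_iters):
        # Assumed error measure: how far M_k is from the identity.
        error = np.max(np.abs(mat_m - identity))
        if error < error_tolerance:
            break  # early termination
        step = ((p + 1) * identity - mat_m) / p           # T_k
        mat_h = mat_h @ step                              # X_{k+1} = X_k T_k
        mat_m = np.linalg.matrix_power(step, p) @ mat_m   # M_{k+1} = T_k^p M_k
    return mat_h
```

For a well-conditioned symmetric positive definite `a`, `x = coupled_newton_inverse_pth_root(a, 2)` should satisfy `x @ x @ (a + 1e-6 * np.eye(len(a)))` ≈ identity. Each step uses only matrix products, which is why this family of methods suits accelerators.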
- Returns:
matrix^(-1/p)
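One way to check a returned value is to compute the same root directly from a symmetric eigendecomposition. This is our suggestion, not part of optax: if A = Q diag(w) Qᵀ, then A^(-1/p) = Q diag(w^(-1/p)) Qᵀ. The helper `eigh_inverse_pth_root` is hypothetical and adds `ridge_epsilon * I` before decomposing. optax itself may scale or apply the ridge term differently, so small discrepancies near the ridge level are expected.

```python
import numpy as np


def eigh_inverse_pth_root(matrix, p, ridge_epsilon=1e-6):
    """Reference (matrix + ridge_epsilon * I)^(-1/p) via an eigendecomposition.

    Hypothetical test helper; assumes `matrix` is symmetric PSD.
    """
    matrix = np.asarray(matrix, dtype=np.float64)
    damped = matrix + ridge_epsilon * np.eye(matrix.shape[0])
    eigvals, eigvecs = np.linalg.eigh(damped)
    # Scale column i of Q by w_i^(-1/p), then multiply by Q^T.
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T
```

With `p = 1` and `ridge_epsilon = 0.0` this reduces to the ordinary inverse. In practice the direct route is mostly useful as a test reference for an iterative implementation.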
References
- [Functions of Matrices: Theory and Computation, Nicholas J. Higham, p. 184, Eq. (7.18)](https://epubs.siam.org/doi/book/10.1137/1.9780898717778)