optax.matrix_inverse_pth_root


optax.matrix_inverse_pth_root(matrix: jax.typing.ArrayLike, p: jax.typing.ArrayLike, num_iters: jax.typing.ArrayLike = 100, ridge_epsilon: jax.typing.ArrayLike = 1e-06, error_tolerance: jax.typing.ArrayLike = 1e-06, precision: Precision = Precision.HIGHEST)

Computes matrix^(-1/p), where p is a positive integer.

This function uses the coupled Newton iteration to compute the inverse pth root of a matrix.
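As background, a standard statement of the coupled Newton iteration for A^(-1/p) (see the Higham reference under References) is given below. Here A stands for the input matrix with the ridge term added, and c > 0 is a scaling constant; how this function chooses c and when it stops are not specified on this page, so read this as the textbook iteration rather than a description of the implementation. The last line is the invariant behind the coupling: because M_k = X_k^p A at every step, driving M_k toward I drives X_k toward A^(-1/p) under the usual convergence conditions.

```latex
\begin{aligned}
X_0 &= \tfrac{1}{c}\, I, \qquad M_0 = \tfrac{1}{c^{p}}\, A, \qquad c > 0, \\
X_{k+1} &= X_k \, \frac{(p+1)\, I - M_k}{p}, \\
M_{k+1} &= \left( \frac{(p+1)\, I - M_k}{p} \right)^{p} M_k, \\
M_k &= X_k^{p} A \quad \text{for all } k, \qquad \text{so } M_k \to I \implies X_k \to A^{-1/p}.
\end{aligned}
```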

Parameters:
  • matrix – The symmetric positive semidefinite (PSD) matrix whose inverse pth root is to be computed.

  • p – Exponent; must be a positive integer.

  • num_iters – Maximum number of iterations.

  • ridge_epsilon – Ridge epsilon added to make the matrix positive definite.

  • error_tolerance – Error threshold for early termination: iteration stops before num_iters once the convergence error falls below this value.

  • precision – XLA precision flag for the matrix multiplications. Available options: a) lax.Precision.DEFAULT (faster step time, lower precision); b) lax.Precision.HIGH (higher precision, slower); c) lax.Precision.HIGHEST (best available precision, slowest).
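The parameters above can be mapped onto the coupled Newton iteration roughly as in the following JAX sketch. It is a minimal illustration under stated assumptions, not optax's implementation: the name coupled_newton_inverse_pth_root is hypothetical, the ridge is added as an absolute shift ridge_epsilon * I (the library may scale it differently), the initial scaling uses the Frobenius norm, and p must be a static Python int. In the sketch, num_iters bounds the loop, error_tolerance stops it early once max|M_k - I| drops below the threshold, and precision is forwarded to the matrix multiplications.

```python
import jax.numpy as jnp
from jax import lax


def coupled_newton_inverse_pth_root(matrix, p, num_iters=100,
                                    ridge_epsilon=1e-6,
                                    error_tolerance=1e-6,
                                    precision=lax.Precision.HIGHEST):
  """Hypothetical sketch: approximates (matrix + ridge_epsilon * I)^(-1/p).

  Returns the approximate root X_k and the final error max|M_k - I|.
  `p` must be a Python int (it is used as a static matrix power).
  """
  n = matrix.shape[0]
  identity = jnp.eye(n, dtype=matrix.dtype)
  # Assumption: the ridge is added as an absolute shift of the diagonal.
  damped = matrix + ridge_epsilon * identity
  # Choose z = 1 / c^p so that M_0 = z * A has eigenvalues in (0, (p + 1) / 2]
  # (assumes A is positive definite; Frobenius norm >= largest eigenvalue).
  z = (1.0 + p) / (2.0 * jnp.linalg.norm(damped))
  m0 = z * damped                      # M_0 = A / c^p
  x0 = identity * z ** (1.0 / p)       # X_0 = I / c
  err0 = jnp.max(jnp.abs(m0 - identity))

  def cond(state):
    i, _, _, err = state
    # Continue while under num_iters and M_k is not yet within tolerance of I.
    return jnp.logical_and(i < num_iters, err > error_tolerance)

  def body(state):
    i, m, x, _ = state
    step = ((p + 1) * identity - m) / p
    # M_{k+1} = step^p M_k and X_{k+1} = X_k step.
    m_new = jnp.matmul(jnp.linalg.matrix_power(step, p), m,
                       precision=precision)
    x_new = jnp.matmul(x, step, precision=precision)
    err = jnp.max(jnp.abs(m_new - identity))
    return i + 1, m_new, x_new, err

  init = (jnp.array(0, dtype=jnp.int32), m0, x0, err0)
  _, _, x, err = lax.while_loop(cond, body, init)
  return x, err


# Usage on a small symmetric positive definite matrix.
a = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
x, err = coupled_newton_inverse_pth_root(a, p=2, ridge_epsilon=0.0)
# x @ x should approximate the inverse of a; err is the final max|M_k - I|.
```

The iteration uses only matrix products, which is why a precision flag matters here: on accelerators, lower matmul precision trades accuracy of the root for speed.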

Returns:

matrix^(-1/p)
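For a symmetric PSD matrix, the target quantity can also be computed directly from an eigendecomposition by raising each eigenvalue to the power -1/p, which gives a convenient reference for checking an iterative result on small matrices. The JAX sketch below does this with jnp.linalg.eigh; the name inverse_pth_root_eigh is hypothetical, this is not the algorithm this function uses, and adding ridge_epsilon * I before decomposing is an assumption meant to mirror the ridge_epsilon parameter.

```python
import jax.numpy as jnp


def inverse_pth_root_eigh(matrix, p, ridge_epsilon=1e-6):
  """Hypothetical reference: (matrix + ridge_epsilon * I)^(-1/p) via eigh.

  Assumes `matrix` is symmetric PSD. With ridge_epsilon = 0, a singular
  matrix (zero eigenvalue) yields inf/nan, since 0 ** (-1/p) is undefined.
  """
  n = matrix.shape[0]
  damped = matrix + ridge_epsilon * jnp.eye(n, dtype=matrix.dtype)
  # Symmetric eigendecomposition: damped = V diag(w) V^T.
  w, v = jnp.linalg.eigh(damped)
  # V diag(w^(-1/p)) V^T, with the diagonal scaling done by broadcasting.
  return (v * w ** (-1.0 / p)) @ v.T


a = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
x = inverse_pth_root_eigh(a, p=2, ridge_epsilon=0.0)
# For p = 2, x @ x approximates the inverse of a, so x @ x @ a is close to I.
is_close = jnp.allclose(x @ x @ a, jnp.eye(3), atol=1e-4)
```

The eigendecomposition route is exact up to floating-point error but costs a full eigh per call; iterative methods that use only matrix products are typically preferred when the root is recomputed often on accelerators.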

References

Nicholas J. Higham, Functions of Matrices: Theory and Computation, p. 184, Eq. 7.18. https://epubs.siam.org/doi/book/10.1137/1.9780898717778