optax.second_order.fisher_diag

optax.second_order.fisher_diag(negative_log_likelihood: LossFn, params: Any, inputs: Array, targets: Array) → Array

Computes the diagonal of the (observed) Fisher information matrix.
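As a sketch of the quantity involved, assume "observed" means the Fisher is estimated from the single batch (inputs, targets) as the outer product of the loss gradient with itself. Write ℓ(θ) for negative_log_likelihood evaluated at parameters θ (params flattened into one vector). Under that assumption the diagonal is just the squared gradient:

```latex
F(\theta) = \nabla_\theta \ell(\theta)\, \nabla_\theta \ell(\theta)^{\top},
\qquad
F_{ii}(\theta) = \left( \frac{\partial \ell(\theta)}{\partial \theta_i} \right)^{2}
```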

Parameters:
  • negative_log_likelihood – the negative log-likelihood function with expected signature loss = fn(params, inputs, targets).

  • params – model parameters.

  • inputs – inputs at which negative_log_likelihood is evaluated.

  • targets – targets at which negative_log_likelihood is evaluated.
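The snippet below is a hedged illustration of a function with the expected signature loss = fn(params, inputs, targets). The model (linear regression with unit-variance Gaussian noise, constant terms dropped), the parameter names w and b, and the data are all hypothetical; any scalar-valued JAX function of a params pytree could play the same role.

```python
import jax.numpy as jnp

# Hypothetical NLL: linear regression with unit-variance Gaussian noise.
# The constant 0.5*log(2*pi) per example is dropped; the loss is summed over the batch.
def negative_log_likelihood(params, inputs, targets):
    preds = inputs @ params["w"] + params["b"]
    return 0.5 * jnp.sum((preds - targets) ** 2)

# Hypothetical params pytree and a batch of two examples.
params = {"w": jnp.array([0.5, -1.0, 2.0]), "b": jnp.array(0.1)}
inputs = jnp.array([[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]])
targets = jnp.array([4.0, -3.0])

loss = negative_log_likelihood(params, inputs, targets)  # scalar, ≈ 0.185
```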

Returns:

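A 1-D Array containing the diagonal of the observed Fisher information matrix, i.e. the elementwise square of the gradient of negative_log_likelihood with respect to params, evaluated at (params, inputs, targets) and flattened into a single vector.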
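A minimal sketch of what the returned array holds, written in plain JAX rather than by calling Optax so it does not depend on a particular Optax version. It assumes the result is the gradient flattened with ravel_pytree and squared elementwise; fisher_diag_sketch and the toy regression loss are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical NLL with the expected signature loss = fn(params, inputs, targets).
def negative_log_likelihood(params, inputs, targets):
    preds = inputs @ params["w"] + params["b"]
    return 0.5 * jnp.sum((preds - targets) ** 2)

def fisher_diag_sketch(nll, params, inputs, targets):
    # Gradient of the NLL with respect to every leaf of the params pytree.
    grads = jax.grad(nll)(params, inputs, targets)
    # Flatten the gradient pytree into one 1-D vector (dict keys in sorted order).
    flat_grads, _ = ravel_pytree(grads)
    # The diagonal of the outer product g g^T is g_i ** 2.
    return jnp.square(flat_grads)

params = {"w": jnp.array([0.5, -1.0, 2.0]), "b": jnp.array(0.1)}
inputs = jnp.array([[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]])
targets = jnp.array([4.0, -3.0])

diag = fisher_diag_sketch(negative_log_likelihood, params, inputs, targets)
# Flat order is b, then w: approximately [0.49, 0.36, 0.01, 1.21].

# Convenience of this sketch: map the flat vector back onto the params structure.
_, unravel = ravel_pytree(params)
per_leaf = unravel(diag)
```

Because the result is a single flat vector, its length equals the total number of scalar parameters (here 4); the unravel step is only needed if per-leaf values are wanted.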