optax.second_order.fisher_diag
- optax.second_order.fisher_diag(negative_log_likelihood: LossFn, params: Any, inputs: Array, targets: Array) → Array
Computes the diagonal of the (observed) Fisher information matrix.
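For reference, here is one way to write the quantity down. This is a sketch that assumes the function squares the gradient of the whole loss elementwise, as the optax source does, and does not sum per-example outer products. Write $\ell(\theta; x, y)$ for negative_log_likelihood evaluated at parameters $\theta$, inputs $x$ and targets $y$, and $g$ for its gradient flattened into a vector. The observed Fisher is then the rank-one matrix $g g^\top$, and its diagonal is:

```latex
g = \nabla_\theta\, \ell(\theta; x, y), \qquad
\operatorname{diag}\!\left(g g^\top\right)_i = g_i^{2}
```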
- Parameters:
negative_log_likelihood – the negative log likelihood function with expected signature loss = fn(params, inputs, targets).
params – model parameters.
inputs – inputs at which negative_log_likelihood is evaluated.
targets – targets at which negative_log_likelihood is evaluated.
- Returns:
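A 1-D Array holding the diagonal of the observed Fisher information matrix. This is the elementwise square of the gradient of negative_log_likelihood evaluated at (params, inputs, targets), with the gradient pytree flattened into a single vector.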
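The following is a minimal JAX sketch of the same computation. It is not optax's code. `fisher_diag_sketch`, the linear-model `nll` and the toy data are all hypothetical, chosen only for illustration. The sketch takes the gradient with `jax.grad`, flattens the parameter pytree with `jax.flatten_util.ravel_pytree`, and squares the result elementwise.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def fisher_diag_sketch(negative_log_likelihood, params, inputs, targets):
  """Hypothetical re-implementation: squared, flattened gradient of the NLL."""
  grads = jax.grad(negative_log_likelihood)(params, inputs, targets)
  # Concatenate every leaf of the gradient pytree into one 1-D vector.
  flat_grads, _ = ravel_pytree(grads)
  # Diagonal of the rank-one observed Fisher g g^T is g_i ** 2.
  return jnp.square(flat_grads)


def nll(params, inputs, targets):
  # Hypothetical Gaussian NLL (unit variance, constants dropped) for a linear model.
  preds = inputs @ params["w"] + params["b"]
  return 0.5 * jnp.sum((preds - targets) ** 2)


params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.5)}
inputs = jnp.array([[1.0, 2.0], [3.0, 0.0]])
targets = jnp.array([0.0, 1.0])

diag = fisher_diag_sketch(nll, params, inputs, targets)
print(diag.shape)  # (3,): one entry per scalar parameter
```

The output is flat, so its entries follow JAX's pytree leaf order. For dicts, keys are sorted, which puts `b` before `w` here. To map the vector back onto the parameter structure, keep the unravel function that `ravel_pytree` returns instead of discarding it.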