optax.second_order.hessian_diag
- optax.second_order.hessian_diag(loss: LossFn, params: Any, inputs: Array, targets: Array) → Array
Computes the diagonal of the Hessian of loss with respect to params, evaluated at (inputs, targets).
- Parameters:
loss – the loss function.
params – model parameters.
inputs – inputs at which loss is evaluated.
targets – targets at which loss is evaluated.
- Returns:
An Array containing the diagonal of the Hessian of loss evaluated at (params, inputs, targets).
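
The following is a minimal sketch of the same quantity written in plain JAX, intended only to illustrate what "diagonal of the Hessian" means here. It is not the library's implementation. The helper name `hessian_diag_sketch`, the toy linear model, and the assumed loss signature `loss(params, inputs, targets)` are all illustrative assumptions. The sketch builds the full Hessian before taking its diagonal, so it is only practical for small parameter counts.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def loss(params, inputs, targets):
    # Toy mean-squared-error loss for a linear model (illustrative only).
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)


def hessian_diag_sketch(loss_fn, params, inputs, targets):
    # Hypothetical helper: flatten the parameter pytree into one vector so
    # the Hessian is an ordinary (n, n) matrix, then take its diagonal.
    flat_params, unravel = ravel_pytree(params)

    def flat_loss(p):
        return loss_fn(unravel(p), inputs, targets)

    full_hessian = jax.hessian(flat_loss)(flat_params)
    return jnp.diag(full_hessian)


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.array([1.0, 0.0, -1.0])

# One entry per scalar parameter, in the flattened order given by ravel_pytree.
diag = hessian_diag_sketch(loss, params, inputs, targets)
```

The result has one entry per scalar parameter, so its length equals the total number of parameters after flattening.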