optax.second_order.hessian_diag

optax.second_order.hessian_diag(loss: LossFn, params: Any, inputs: Array, targets: Array) → Array

Computes the diagonal of the Hessian of loss with respect to params, evaluated at (inputs, targets).

Parameters:
  • loss – the loss function.

  • params – model parameters.

  • inputs – inputs at which loss is evaluated.

  • targets – targets at which loss is evaluated.

Returns:

A DeviceArray containing the diagonal of the Hessian of loss with respect to params, evaluated at (params, inputs, targets).
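
Example:

The quantity described above can be illustrated with a minimal plain-JAX sketch. This is not the library's implementation: it forms the dense Hessian over the flattened parameters and reads off its diagonal, which is only practical for small models. The toy loss, the helper name hessian_diag_reference, and the assumed call signature loss(params, inputs, targets) are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


# Hypothetical toy loss; the signature loss(params, inputs, targets) is an assumption.
def loss(params, inputs, targets):
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)


def hessian_diag_reference(loss, params, inputs, targets):
    """Illustrative helper: diagonal of the dense Hessian w.r.t. flattened params."""
    flat_params, unravel = ravel_pytree(params)

    # Re-express the loss as a function of a single flat parameter vector.
    def flat_loss(p):
        return loss(unravel(p), inputs, targets)

    # Dense Hessian is O(n^2) in memory; fine for a small example only.
    return jnp.diag(jax.hessian(flat_loss)(flat_params))


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([1.0, 0.0])

# One entry per scalar parameter, in ravel_pytree order: [b, w[0], w[1]].
diag = hessian_diag_reference(loss, params, inputs, targets)
```

For this quadratic loss the Hessian is (2/N) Σ aᵢaᵢᵀ with aᵢ = [1, xᵢ], so the diagonal can be checked by hand against the returned array.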