optax.second_order.hvp

optax.second_order.hvp(loss: LossFn, v: Array, params: Any, inputs: Array, targets: Array) → Array

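Performs an efficient vector-Hessian product: multiplies v by the Hessian of loss with respect to params without materializing the full Hessian matrix. Because the Hessian is symmetric, this equals the Hessian-vector product H v.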
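A minimal sketch of one common way to compute such a product in plain JAX, assuming loss(params, inputs, targets) returns a scalar: differentiate the gradient in the direction of v with jax.jvp (forward-over-reverse), so the Hessian is never formed. The names hvp_sketch and squared_error and the toy data are hypothetical illustrations, not the optax implementation; flattening the result to match v is a choice made for this sketch.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def hvp_sketch(loss, v, params, inputs, targets):
    """Hypothetical Hessian-vector product via forward-over-reverse AD.

    Assumes loss(params, inputs, targets) returns a scalar and v is a flat
    vector with as many entries as ravel_pytree(params)[0].
    """
    _, unravel = ravel_pytree(params)
    # Close over the data so only params are differentiated.
    loss_fn = lambda p: loss(p, inputs, targets)
    # The JVP of the gradient along v is H @ v; H itself is never built.
    _, hv_tree = jax.jvp(jax.grad(loss_fn), (params,), (unravel(v),))
    # Flatten the pytree result so it has the same layout as v.
    hv_flat, _ = ravel_pytree(hv_tree)
    return hv_flat


def squared_error(params, inputs, targets):
    # Linear model; the Hessian of this loss w.r.t. params["w"] is X^T X.
    preds = inputs @ params["w"]
    return 0.5 * jnp.sum((preds - targets) ** 2)


X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])
params = {"w": jnp.zeros(2)}
v = jnp.array([1.0, -1.0])

hv = hvp_sketch(squared_error, v, params, X, y)
```

For this quadratic loss the Hessian is X^T X regardless of params, so hv should match X.T @ X @ v computed densely.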

Parameters:
  • loss – the loss function, called as loss(params, inputs, targets).

  • v – a flat vector with as many entries as ravel(params), i.e. one entry per scalar model parameter.

  • params – model parameters.

  • inputs – inputs at which loss is evaluated.

  • targets – targets at which loss is evaluated.

Returns:

An Array corresponding to the product of v and the Hessian of loss evaluated at (params, inputs, targets).
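For small models, a dense reference can help check an HVP result and illustrates the shape contract: v has one entry per element of the flattened params, and the vector-Hessian product v^T H coincides with H v. The sketch below is plain JAX, not optax; dense_hessian, the logistic-style loss and the toy data are hypothetical, and materializing H costs memory quadratic in the parameter count, which is exactly what an efficient HVP avoids.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def dense_hessian(loss, params, inputs, targets):
    """Hypothetical reference: full Hessian of loss w.r.t. flattened params.

    Memory grows as n**2 for n parameters, so use only on small models.
    """
    flat, unravel = ravel_pytree(params)
    flat_loss = lambda f: loss(unravel(f), inputs, targets)
    return jax.hessian(flat_loss)(flat)


def loss(params, inputs, targets):
    # Logistic loss on a tiny linear model with a weight vector and a bias.
    logits = inputs @ params["w"] + params["b"]
    return jnp.mean(jnp.log1p(jnp.exp(-targets * logits)))


X = jnp.array([[0.5, -1.0], [1.5, 2.0], [-0.3, 0.7]])
y = jnp.array([1.0, -1.0, 1.0])
params = {"w": jnp.array([0.1, -0.2]), "b": jnp.array(0.3)}

H = dense_hessian(loss, params, X, y)
n = ravel_pytree(params)[0].size  # v must have this many entries
v = jnp.arange(1.0, n + 1.0)

# H is symmetric, so the vector-Hessian and Hessian-vector products agree.
vH = v @ H
Hv = H @ v
```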