optax.contrib.hutchinson_estimator_diag_hessian

optax.contrib.hutchinson_estimator_diag_hessian(random_seed: Array | None = None)

Returns a GradientTransformationExtraArgs that estimates the diagonal of the Hessian.

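The Hessian diagonal is estimated with Hutchinson’s estimator. In its standard form, the estimator draws a random vector z with independent Rademacher (±1) entries and uses the elementwise product z ⊙ (Hz), whose expectation is diag(H); the estimate is therefore unbiased. It can, however, have high variance: for entry i, the variance equals the sum of the squared off-diagonal Hessian entries H_ij (j ≠ i) in row i.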
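The identity diag(H) = E[z ⊙ (Hz)] can be checked numerically with a minimal sketch in plain JAX. The helpers hvp and hutchinson_diag are hypothetical names, not part of optax, and this is an illustration of the technique rather than the optax implementation. The Hessian-vector product is computed forward-over-reverse (jax.jvp of jax.grad), so the full Hessian is never materialized. For a function with a diagonal Hessian, every sample is exact because z_i² = 1; for a non-diagonal Hessian, single samples are noisy and only their average approaches the true diagonal.

```python
import jax
import jax.numpy as jnp


def hvp(f, x, v):
    """Hessian-vector product H(x) @ v, computed forward-over-reverse."""
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


def hutchinson_diag(f, x, key, num_samples=1):
    """Estimates diag(H(x)) by averaging z * (H z) over Rademacher vectors z."""
    keys = jax.random.split(key, num_samples)

    def one_sample(k):
        # Rademacher probe vector: each entry is -1 or +1 with equal probability.
        z = jax.random.rademacher(k, x.shape, dtype=x.dtype)
        # The elementwise product z_i * (H z)_i has expectation H_ii.
        return z * hvp(f, x, z)

    return jnp.mean(jax.vmap(one_sample)(keys), axis=0)


# Diagonal Hessian diag(6 * x): each sample is exact because z_i ** 2 == 1.
def cubic(x):
    return jnp.sum(x ** 3)


# Non-diagonal Hessian A: single samples are noisy, averages approach diag(A).
A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])


def quadratic(x):
    return 0.5 * x @ A @ x


x = jnp.array([1.0, 2.0, 3.0])
key = jax.random.PRNGKey(0)
cubic_diag = hutchinson_diag(cubic, x, key)  # [6., 12., 18.]
quad_diag = hutchinson_diag(quadratic, x, key, num_samples=10_000)  # close to [2., 3., 4.]
```

The quadratic case shows where the variance comes from: each sample of entry i is off by the terms H_ij z_i z_j for j ≠ i, which cancel only on average.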

Parameters:

random_seed – JAX PRNG key used to generate the random vectors for the estimator.

Returns:

GradientTransformationExtraArgs
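The signature suggests a stateful transformation: a PRNG key seeds the random vectors, and estimating a Hessian needs the parameters and an objective in addition to the usual updates. A minimal sketch of such an init/update pair, under those assumptions, is shown below. The names hutchinson_sketch, HutchinsonSketchState, and the obj_fn keyword are hypothetical; how optax passes the objective to update, and which key it uses when random_seed is None, are assumptions here, not documented behavior. The sketch splits the stored key on every update so that each step draws fresh Rademacher vectors over the whole parameter pytree.

```python
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp


class HutchinsonSketchState(NamedTuple):
    key: jax.Array  # PRNG key; split on every update so each step uses fresh vectors


def hutchinson_sketch(random_seed: Optional[jax.Array] = None):
    """Hypothetical (init_fn, update_fn) pair shaped like a GradientTransformationExtraArgs."""

    def init_fn(params):
        del params  # the state holds only the PRNG key
        # Assumed fallback when no key is supplied; not necessarily optax's default.
        key = random_seed if random_seed is not None else jax.random.PRNGKey(0)
        return HutchinsonSketchState(key=key)

    def update_fn(updates, state, params, obj_fn: Callable):
        del updates  # the estimate depends only on params and obj_fn
        key, subkey = jax.random.split(state.key)
        leaves, treedef = jax.tree_util.tree_flatten(params)
        leaf_keys = jax.random.split(subkey, len(leaves))
        # One Rademacher (+1/-1) probe per parameter leaf, matching shape and dtype.
        signs = jax.tree_util.tree_unflatten(
            treedef,
            [jax.random.rademacher(k, leaf.shape, dtype=leaf.dtype)
             for k, leaf in zip(leaf_keys, leaves)],
        )
        # Hessian-vector product over the whole parameter pytree.
        hvp = jax.jvp(jax.grad(obj_fn), (params,), (signs,))[1]
        diag = jax.tree_util.tree_map(lambda h, z: h * z, hvp, signs)
        return diag, HutchinsonSketchState(key=key)

    return init_fn, update_fn


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}


def loss(p):
    return jnp.sum(p["w"] ** 3) + p["b"] ** 4


init_fn, update_fn = hutchinson_sketch(jax.random.PRNGKey(42))
state0 = init_fn(params)
diag, state1 = update_fn(jax.grad(loss)(params), state0, params, obj_fn=loss)
# This loss has a diagonal Hessian, so diag is exact: {"b": 3.0, "w": [6., 12.]}.
```

Keeping the key in the transformation state, rather than drawing it globally, keeps the update a pure function that can be jit-compiled and replayed deterministically from the same random_seed.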