optax.scale_by_polyak

optax.scale_by_polyak(f_min: jax.typing.ArrayLike = 0.0, max_learning_rate: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 0.0, variant: str = 'sps') → base.GradientTransformationExtraArgs

Scales the update by Polyak's step-size.

See optax.polyak_sgd() for more details.

Parameters:
  • f_min – a lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the formula given in optax.polyak_sgd().

  • max_learning_rate – a maximum step size to use (defaults to 1).

  • eps – a value to add in the denominator of the update (defaults to 0).

  • variant – either 'sps' or 'sps+' (defaults to 'sps').
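
To illustrate how these parameters could interact, here is a minimal sketch in plain JAX. It assumes the commonly stated Polyak step size, min((value - f_min) / (||grad||^2 + eps), max_learning_rate), and assumes that 'sps+' clamps the gap value - f_min at zero. The helper polyak_step_size is hypothetical and is not optax's own implementation; see optax.polyak_sgd() for the authoritative formula.

```python
import jax
import jax.numpy as jnp


def polyak_step_size(value, grads, f_min=0.0, max_learning_rate=1.0,
                     eps=0.0, variant="sps"):
    """Hypothetical helper: textbook Polyak step size (an assumption, not optax code)."""
    gap = value - f_min
    if variant == "sps+":
        # Assumed behaviour of 'sps+': never let the gap go negative.
        gap = jnp.maximum(gap, 0.0)
    # Squared L2 norm of the gradient, summed over every leaf of the pytree.
    sq_norm = sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads))
    # eps keeps the denominator away from zero; max_learning_rate caps the step.
    return jnp.minimum(gap / (sq_norm + eps), max_learning_rate)


grads = {"w": jnp.array([3.0, 4.0])}  # squared norm is 25

step = polyak_step_size(value=5.0, grads=grads)      # gap / norm = 5 / 25
capped = polyak_step_size(value=100.0, grads=grads)  # 100 / 25 exceeds the cap
clamped = polyak_step_size(value=-1.0, grads=grads, variant="sps+")  # negative gap
```

Under these assumptions, step is 0.2, capped is limited to max_learning_rate (1.0), and clamped is 0 because the objective value lies below f_min.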

Returns:

An optax.GradientTransformationExtraArgs whose update function takes an additional keyword argument, value, containing the current value of the objective function.