optax.scale_by_polyak
- optax.scale_by_polyak(f_min: jax.typing.ArrayLike = 0.0, max_learning_rate: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 0.0, variant: str = 'sps') → base.GradientTransformationExtraArgs
Scales the update by Polyak's step-size.
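For reference, the step size below is restated from the formula documented for optax.polyak_sgd() (without that optimizer's extra scaling factor). Here \(f^\star\) corresponds to f_min, \(\gamma_{\max}\) to max_learning_rate and \(\epsilon\) to eps; the transformation multiplies the incoming update by \(\gamma\):

\[
\gamma = \min\left\{ \frac{f(x) - f^\star}{\|\nabla f(x)\|^2 + \epsilon},\ \gamma_{\max} \right\}
\qquad \text{('sps')}
\]

\[
\gamma = \min\left\{ \frac{(f(x) - f^\star)_+}{\|\nabla f(x)\|^2 + \epsilon},\ \gamma_{\max} \right\},
\quad (a)_+ = \max(a, 0)
\qquad \text{('sps+')}
\]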
See optax.polyak_sgd() for more details.
- Parameters:
f_min – a lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the Polyak step-size formula documented in optax.polyak_sgd().
max_learning_rate – a maximum step size to use (defaults to 1).
eps – a value to add in the denominator of the update (defaults to 0).
variant – either 'sps' or 'sps+' (defaults to 'sps').
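To show how the four parameters interact, here is a minimal NumPy sketch of the scalar step-size computation, assuming the formula documented for optax.polyak_sgd(). The helper polyak_step_size is hypothetical (it is not part of optax), and returning 0 when the denominator is zero is this sketch's own choice rather than a documented optax behavior.

```python
import numpy as np


def polyak_step_size(value, grad, f_min=0.0, max_learning_rate=1.0,
                     eps=0.0, variant="sps"):
    """Hypothetical helper: scalar Polyak step size for one update.

    value: current objective value f(x).
    grad:  gradient of f at x, as a NumPy array.
    """
    gap = value - f_min  # suboptimality gap f(x) - f*
    if variant == "sps+":
        gap = max(gap, 0.0)  # 'sps+' keeps only the non-negative part
    elif variant != "sps":
        raise ValueError(f"unknown variant: {variant!r}")
    denom = float(np.sum(np.square(grad))) + eps  # ||grad||^2 + eps
    if denom == 0.0:
        # Assumption of this sketch: take no step if the denominator vanishes.
        return 0.0
    # Cap the step at max_learning_rate.
    return min(gap / denom, max_learning_rate)
```

For example, with a gradient of (3, 4), so that the squared norm is 25, and a gap f(x) − f_min of 10, the step is 0.4; a gap of 100 would give 4, which is capped at max_learning_rate = 1. With 'sps', a value below f_min yields a negative step, whereas 'sps+' clips the gap to zero and so takes no step.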
- Returns:
A optax.GradientTransformationExtraArgs, where the update function takes an additional keyword argument, value, containing the current value of the objective function.