optax.scale_by_trust_ratio

optax.scale_by_trust_ratio(min_norm: jax.typing.ArrayLike = 0.0, trust_coefficient: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 0.0) → optax.GradientTransformation

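Scale updates by the trust ratio. The ratio is computed separately for each parameter array (each leaf of the params pytree) from the norm of the parameters relative to the norm of the update. Because it depends on the current parameters, params must be passed to the transformation's update function.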
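
Written out, the per-leaf scaling appears to be the following. This is a sketch based on our reading of the optax source, not a formal specification. Here $w$ is one parameter array, $u$ its incoming update, and $\lVert\cdot\rVert_2$ the L2 norm taken over all of the array's elements:

```latex
\begin{aligned}
n_w &= \max\bigl(\lVert w \rVert_2,\ \text{min\_norm}\bigr), \qquad
n_u  = \max\bigl(\lVert u \rVert_2,\ \text{min\_norm}\bigr), \\
r   &= \begin{cases}
         1 & \text{if } n_w = 0 \text{ or } n_u = 0, \\[4pt]
         \text{trust\_coefficient} \cdot \dfrac{n_w}{n_u + \text{eps}} & \text{otherwise,}
       \end{cases} \\
u'  &= r \, u .
\end{aligned}
```

With the defaults (min_norm = 0, eps = 0) and both norms nonzero, $\lVert u' \rVert_2 = \text{trust\_coefficient} \cdot \lVert w \rVert_2$. The step size for each leaf is therefore tied to the size of its parameters rather than to the raw gradient scale.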

Used by optax.fromage(), optax.lars(), and optax.lamb().

Parameters:
  • min_norm – Lower bound applied to both the parameter norm and the update norm before the ratio is computed; defaults to 0.0 (no clipping).

  • trust_coefficient – A multiplier for the trust ratio.

  • eps – Constant added to the denominator of the trust ratio (the update norm) for numerical stability; defaults to 0.0.
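
To make the roles of min_norm, trust_coefficient and eps concrete, here is a minimal from-scratch sketch in JAX of the per-leaf computation. It assumes the formula ratio = trust_coefficient × max(‖param‖, min_norm) / (max(‖update‖, min_norm) + eps), with the ratio replaced by 1 when either clipped norm is zero. The function name trust_ratio_scale is hypothetical and is not part of optax. The sketch also does not try to make the norms safe to differentiate at zero.

```python
import jax
import jax.numpy as jnp


def trust_ratio_scale(updates, params, min_norm=0.0, trust_coefficient=1.0, eps=0.0):
    """Hypothetical re-implementation of per-leaf trust-ratio scaling."""

    def scale_leaf(update, param):
        # L2 norm over all elements of each leaf, clipped from below at min_norm.
        param_norm = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(param))), min_norm)
        update_norm = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(update))), min_norm)
        ratio = trust_coefficient * param_norm / (update_norm + eps)
        # If either norm is zero, leave the update unscaled (ratio of 1).
        zero_norm = jnp.logical_or(param_norm == 0.0, update_norm == 0.0)
        ratio = jnp.where(zero_norm, 1.0, ratio)
        return update * ratio

    # Scale each update leaf using the matching parameter leaf.
    return jax.tree_util.tree_map(scale_leaf, updates, params)


params = {"w": jnp.array([3.0, 4.0]), "b": jnp.zeros(2)}
updates = {"w": jnp.array([0.0, 2.0]), "b": jnp.array([1.0, 1.0])}
scaled = trust_ratio_scale(updates, params)
```

Here ‖w‖ = 5 and ‖update_w‖ = 2, so with the defaults the w update is scaled by 2.5 to [0, 5]. The zero-valued b leaf falls back to a ratio of 1. With min_norm=4.0 the update norm is clipped up to 4, and the w update becomes [0, 2.5].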

Returns:

An optax.GradientTransformation object.