optax.scale_by_trust_ratio
- optax.scale_by_trust_ratio(min_norm: jax.typing.ArrayLike = 0.0, trust_coefficient: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 0.0) → optax.GradientTransformation
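Scale updates by the trust ratio: each leaf of the updates is rescaled by the ratio of the corresponding parameter norm to the update norm. Because the ratio depends on the parameters, `params` must be passed to `update`.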
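As a sketch of the per-leaf rule (inferred from the optax source rather than stated on this page, so treat the exact clipping and zero-norm fallback as assumptions), each update leaf u and its matching parameter leaf p are combined as follows, with min_norm, trust_coefficient and eps as in the parameter list:

```latex
\begin{aligned}
\|p\|_c &= \max\bigl(\|p\|_2,\ \mathrm{min\_norm}\bigr), \qquad
\|u\|_c = \max\bigl(\|u\|_2,\ \mathrm{min\_norm}\bigr), \\
r &= \begin{cases}
  1 & \text{if } \|p\|_c = 0 \text{ or } \|u\|_c = 0, \\
  \mathrm{trust\_coefficient} \cdot \dfrac{\|p\|_c}{\|u\|_c + \mathrm{eps}} & \text{otherwise,}
\end{cases} \\
u' &= r\,u .
\end{aligned}
```

Each leaf's step is thus sized relative to its own parameter norm, which is the layer-wise scaling that optax.fromage, optax.lars and optax.lamb build on.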
Used in optax.fromage(), optax.lars(), optax.lamb().
- Parameters:
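min_norm – Lower bound applied to both the parameter norm and the update (gradient) norm before the ratio is computed; defaults to zero, i.e. no clipping.
trust_coefficient – Multiplier applied to the trust ratio; defaults to 1.0.
eps – Additive constant in the denominator for numerical stability; defaults to 0.0.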
- Returns:
An optax.GradientTransformation object.