optax.scale_by_sm3
- optax.scale_by_sm3(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 1e-08) -> optax.GradientTransformation
Scale updates by sm3.
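To illustrate the memory-saving idea behind SM3, here is a minimal sketch for a single 2-D gradient. It assumes the SM3-II cover-set rule, with one accumulator per row and one per column. The helper `sm3_matrix_step` is hypothetical and is not part of optax. It follows the roles of `b2` and `eps` described in the parameter list but omits the `b1` averaging of updates, and optax's internal code may differ in its details.

```python
import jax.numpy as jnp


def sm3_matrix_step(g, row_acc, col_acc, b2=1.0, eps=1e-8):
    """One SM3-style step for a 2-D gradient g (hypothetical helper, not optax).

    For g of shape (m, n), row_acc has shape (m,) and col_acc has shape (n,),
    so the state costs m + n values instead of m * n.
    """
    # Each entry (i, j) is covered by row i and column j; take the tighter bound.
    prev = jnp.minimum(row_acc[:, None], col_acc[None, :])
    # b2 == 1.0: plain accumulation (Adagrad-like); otherwise an EMA of g**2.
    if b2 == 1.0:
        nu = g**2 + prev
    else:
        nu = (1.0 - b2) * g**2 + b2 * prev
    # Scale by 1 / sqrt(nu + eps); entries with a zero accumulator stay at 0.
    update = jnp.where(nu > 0, g / jnp.sqrt(nu + eps), 0.0)
    # Compress nu back into per-row and per-column maxima for the next step.
    return update, jnp.max(nu, axis=1), jnp.max(nu, axis=0)


g = jnp.array([[1.0, -2.0], [0.5, 0.0]])
# Step 1 from zero state: each nonzero entry is scaled to roughly sign(g).
update, rows, cols = sm3_matrix_step(g, jnp.zeros(2), jnp.zeros(2))
# Step 2 with the same gradient: accumulators grow, so updates shrink.
# Entry (1, 1) now has a nonzero accumulator even though its gradient was
# always 0, because its row/column bounds over-estimate the true sum.
update2, rows2, cols2 = sm3_matrix_step(g, rows, cols)
```

Storing `row_acc` and `col_acc` costs m + n values per matrix rather than the m × n a full Adagrad-style accumulator would need. The price is that each entry only sees an upper bound of its own accumulated squared gradient.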
See optax.sm3() for more details.
- Parameters:
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
- Returns:
An optax.GradientTransformation object.