optax.scale_by_sm3

optax.scale_by_sm3(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 1e-08) -> optax.GradientTransformation

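Scale updates by SM3, a memory-efficient adaptive method. For a parameter with two or more dimensions, it keeps one squared-gradient accumulator per index along each axis (M + N values for an M × N matrix) instead of one per entry.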

See optax.sm3() for more details.
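To make the row/column bookkeeping concrete, here is a minimal NumPy sketch of one SM3 step for a single rank-2 gradient. It is our paraphrase of how scale_by_sm3 treats matrix-shaped leaves, not the library code; the helper name sm3_matrix_step and the toy gradient are hypothetical. The parameters b1, b2 and eps play the roles described under Parameters.

```python
import numpy as np


def sm3_matrix_step(g, row_acc, col_acc, mom, b1=0.9, b2=1.0, eps=1e-8):
    """One SM3 step for a single 2-D gradient g of shape (M, N).

    Hypothetical helper: a paraphrase of SM3 for rank-2 leaves,
    not the optax implementation itself.
    """
    # Weight for the new squared gradient: with b2 == 1.0 the squares are
    # simply added (no decay), otherwise they are exponentially averaged.
    w_new = 1.0 if b2 == 1.0 else 1.0 - b2
    # Per-entry accumulator: the tighter of its row and column bounds,
    # plus the new squared gradient.
    nu = w_new * g**2 + b2 * np.minimum(row_acc[:, None], col_acc[None, :])
    # Precondition the gradient; entries whose accumulator is zero get zero.
    scaled = np.where(nu > 0, g / np.sqrt(nu + eps), 0.0)
    # Compress the state back to M + N numbers: max over each row / column.
    row_acc = nu.max(axis=1)
    col_acc = nu.max(axis=0)
    # Exponentially weighted average of the scaled updates, controlled by b1.
    mom = b1 * mom + (1.0 - b1) * scaled
    return mom, row_acc, col_acc


g = np.array([[1.0, 2.0], [3.0, 4.0]])
row0, col0, mom0 = np.zeros(2), np.zeros(2), np.zeros_like(g)

# First step: accumulators start at zero, so nu equals g**2.
mom1, row1, col1 = sm3_matrix_step(g, row0, col0, mom0)
# Second step with the same gradient, reusing the compressed state.
mom2, row2, col2 = sm3_matrix_step(g, row1, col1, mom1)
```

After the first step the state is row1 = [4, 16] and col1 = [9, 16]. On the second step the per-entry accumulator becomes [[5, 8], [18, 32]], which is at least the full AdaGrad sum 2 * g**2 = [[2, 8], [18, 32]]; with b2 = 1.0, SM3 trades this looser upper bound for storing only M + N numbers per matrix.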

Parameters:
  • b1 – Decay rate for the exponentially weighted average of grads.

  • b2 – Decay rate for the exponentially weighted average of squared grads. With the default b2 = 1.0, squared grads are summed without decay.

  • eps – Term added to the denominator to improve numerical stability.

Returns:

An optax.GradientTransformation object.