optax.scale_by_lion

optax.scale_by_lion(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.99, mu_dtype: str | type[Any] | dtype | SupportsDType | None = None, *, mode: Literal['hard', 'smooth', 'refined'] = 'hard', smooth_beta: float = 1.0) → optax.GradientTransformation

Rescale updates according to the Lion algorithm.

See optax.lion() for more details.

Parameters:
  • b1 – Rate for combining the momentum and the current grad.

  • b2 – Decay rate for the exponentially weighted average of grads.

  • mu_dtype – Optional dtype to be used for the momentum; if None, the dtype is inferred from `params` and `updates`.

  • mode – Which sign variant to use: 'hard' (sign), 'smooth' (tanh smoothing), or 'refined' (linear around 0, saturating to the sign for large values).

  • smooth_beta – Smoothing factor used when mode == 'smooth'.
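
To make the roles of b1 and b2 concrete, here is a minimal sketch of one scaling step for a single array, covering only the default 'hard' mode. The helper name `lion_scale_step` is hypothetical and not part of optax; the real transformation operates on pytrees and keeps the momentum in its optimizer state.

```python
import jax.numpy as jnp


def lion_scale_step(grad, mu, b1=0.9, b2=0.99):
    """One hard-sign Lion scaling step (illustrative helper, not an optax API)."""
    # b1 blends the stored momentum with the current gradient; only the sign
    # of that blend is kept as the update direction.
    direction = jnp.sign(b1 * mu + (1.0 - b1) * grad)
    # b2 is the decay rate of the exponentially weighted average of gradients
    # that is carried over to the next step.
    new_mu = b2 * mu + (1.0 - b2) * grad
    return direction, new_mu


grad = jnp.array([0.5, -2.0, 0.0])
mu = jnp.zeros_like(grad)  # momentum starts at zero
direction, new_mu = lion_scale_step(grad, mu)
```

With zero initial momentum, the direction is simply the sign of the gradient, and a zero gradient entry yields a zero update.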

Returns:

An optax.GradientTransformation object.