optax.scale_by_lion
- optax.scale_by_lion(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.99, mu_dtype: str | type[Any] | dtype | SupportsDType | None = None, *, mode: Literal['hard', 'smooth', 'refined'] = 'hard', smooth_beta: float = 1.0) → optax.GradientTransformation
Rescale updates according to the Lion algorithm.
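The following is a minimal NumPy sketch of one hard-sign Lion step on a single array, assuming the commonly published Lion rule: interpolate the stored momentum and the current gradient with b1, keep only the sign, then update the momentum as an exponential moving average decayed with b2. The helper `lion_step` is hypothetical and not part of optax. The real transformation works on pytrees, and the learning rate is applied by a later transformation in the chain.

```python
import numpy as np


def lion_step(grad, mu, b1=0.9, b2=0.99):
    """Hypothetical helper: one hard-sign Lion step on a single array.

    Returns (update_direction, new_momentum). The direction is not yet
    scaled by a learning rate.
    """
    # Interpolate between the stored momentum and the current gradient with b1,
    # then keep only the sign of each entry.
    direction = np.sign((1.0 - b1) * grad + b1 * mu)
    # Exponential moving average of gradients, decayed with b2.
    new_mu = b2 * mu + (1.0 - b2) * grad
    return direction, new_mu


grad = np.array([0.5, -2.0, 0.0])
mu = np.array([-1.0, 1.0, 0.3])
direction, mu = lion_step(grad, mu)
# direction == [-1, 1, 1]: with b1 = 0.9 the momentum term decides every sign here.
print(direction)
print(mu)
```

Because only the sign survives, every entry of the update has the same magnitude; this is why Lion typically pairs with a smaller learning rate than Adam-style optimizers.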
See optax.lion() for more details.
- Parameters:
b1 – Interpolation rate for combining the momentum and the current gradient.
b2 – Decay rate for the exponentially weighted average of gradients.
mu_dtype – Optional dtype to be used for the momentum; if None, the dtype is inferred from `params` and `updates`.
mode – Which sign variant to use: 'hard' (sign), 'smooth' (tanh smoothing), or 'refined' (linear around 0, saturating to sign for large values).
smooth_beta – Smoothing factor used when mode == 'smooth'.
- Returns:
An optax.GradientTransformation object.
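To illustrate what the `mode` and `smooth_beta` parameters select between, here is a hedged NumPy sketch of three elementwise sign functions. The formulas for 'smooth' and 'refined' are assumptions inferred only from the parameter descriptions: 'smooth' is taken to be tanh(smooth_beta * x), and 'refined' is taken to be the identity on [-1, 1] that saturates to sign(x) outside it (the unit threshold is hypothetical). The helper `sign_variant` is not part of optax.

```python
import numpy as np


def sign_variant(x, mode="hard", smooth_beta=1.0):
    """Hypothetical elementwise sign function for the three `mode` values.

    The 'smooth' and 'refined' formulas are assumptions, not optax's code.
    """
    if mode == "hard":
        # Plain sign: -1, 0, or +1 per entry.
        return np.sign(x)
    if mode == "smooth":
        # Assumed form: tanh(smooth_beta * x); larger smooth_beta -> closer to sign.
        return np.tanh(smooth_beta * x)
    if mode == "refined":
        # Assumed form: linear on [-1, 1], saturating to sign(x) beyond it.
        return np.clip(x, -1.0, 1.0)
    raise ValueError(f"unknown mode: {mode!r}")


x = np.array([-3.0, -0.2, 0.0, 0.5, 2.0])
for mode in ("hard", "smooth", "refined"):
    print(mode, sign_variant(x, mode=mode, smooth_beta=2.0))
```

In a Lion step, a function like this would stand in for the hard sign applied to the b1-interpolated mix of momentum and gradient. The smooth and refined variants keep small entries small instead of pushing them to full magnitude.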