optax.contrib.scale_by_muon
- optax.contrib.scale_by_muon(ns_coeffs: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike] | tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike], ...] = (3.4445, -4.775, 2.0315), ns_steps: jax.typing.ArrayLike = 5, beta: jax.typing.ArrayLike = 0.95, eps: jax.typing.ArrayLike = 1e-08, mu_dtype: jax.typing.DTypeLike | None = None, *, nesterov: bool = True, adaptive: bool = False, preconditioning: Literal['frobenius', 'spectral', 'aol', 'schatten'] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None) → base.GradientTransformation
Rescale updates according to the Muon algorithm.
Muon is a variant of Shampoo that uses the Newton-Schulz method to orthogonalize the momentum accumulated by the optimizer. Mathematically, it performs steepest descent under the Schatten-p norm for some large p. With p = ∞, it is equivalent to Shampoo without accumulation or, equivalently, to steepest descent under the spectral norm.
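The p = ∞ case can be made concrete. If a gradient matrix has reduced SVD G = U diag(s) Vᵀ, the direction that maximizes ⟨G, D⟩ over all D with spectral norm at most 1 is the polar factor UVᵀ, and the maximum equals the nuclear norm of G (the dual of the spectral norm). The sketch below computes UVᵀ exactly with an SVD, as a reference for what the Newton-Schulz iteration approximates; it is an illustration, not how optax computes the update, and `polar_factor` is a hypothetical helper. The `adaptive` option's description (scaling by the dual norm of the original updates) suggests a factor like `nuclear_norm` below, but treating it that way is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def polar_factor(g):
    """Hypothetical helper: return U @ V^T from the reduced SVD g = U diag(s) V^T."""
    u, _, vt = jnp.linalg.svd(g, full_matrices=False)
    return u @ vt


g = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
o = polar_factor(g)

# Every singular value of o is 1, so o has spectral norm 1.
s_o = jnp.linalg.svd(o, compute_uv=False)

# <g, o> = trace(g^T o) equals the nuclear norm of g (sum of its singular
# values), which is the dual norm of the spectral norm.
nuclear_norm = jnp.sum(jnp.linalg.svd(g, compute_uv=False))
alignment = jnp.sum(g * o)

# A random direction rescaled to spectral norm 1 is never better aligned with g.
d = jax.random.normal(jax.random.PRNGKey(1), g.shape)
d = d / jnp.linalg.svd(d, compute_uv=False)[0]
print(s_o, alignment, nuclear_norm, jnp.sum(g * d))
```

A steepest-descent step under the spectral norm would then move the weights along -o, scaled by the learning rate.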
- Parameters:
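ns_coeffs – Coefficients (a, b, c) for the Newton-Schulz iteration: either a single triple reused at every iteration, or a tuple of triples with one triple per iteration.
ns_steps – Number of Newton-Schulz iterations. Ignored if ns_coeffs is a tuple of tuples.
beta – Decay rate for the exponentially weighted average of gradients.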
eps – Term added to denominators to improve numerical stability.
mu_dtype – Data type of the momentum accumulator.
nesterov – Whether to use Nesterov momentum.
adaptive – Whether to scale the updates by the dual norm of the original updates. See <https://arxiv.org/abs/2409.20325>.
preconditioning – Type of preconditioning applied before the Newton-Schulz (NS) iterations. Available options are:
  - ‘frobenius’ (default): Frobenius-norm rescaling before NS.
  - ‘spectral’: spectral-norm rescaling before NS.
  - ‘aol’: AOL rescaling to improve orthogonality.
  - ‘schatten’: Schatten-4-norm rescaling before NS.
weight_dimension_numbers – An optional tree of `MuonDimensionNumbers` with the same structure as the params, specifying how to reshape each parameter before and after orthogonalization, or a callable returning such a tree. None implies that all parameters are 2D matrices.
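Taken together, beta and nesterov produce a momentum direction, preconditioning rescales it, and ns_coeffs/ns_steps drive the Newton-Schulz loop. The following is a minimal JAX sketch of that pipeline for a single 2D matrix, not optax's implementation. Its assumptions: the quintic update X ← aX + b(XXᵀ)X + c(XXᵀ)²X from the modded-nanogpt reference (Jordan, 2024), which maps each singular value s to as + bs³ + cs⁵; an EMA form for both the momentum and the Nesterov look-ahead; and the hypothetical helper names `rescale`, `newton_schulz` and `momentum_direction`. The ‘aol’ option and the `weight_dimension_numbers` reshaping are not covered.

```python
import jax
import jax.numpy as jnp

DEFAULT_NS_COEFFS = (3.4445, -4.775, 2.0315)


def rescale(x, preconditioning="frobenius", eps=1e-8):
    """Hypothetical helper: divide x by a norm that upper-bounds its spectral norm."""
    if preconditioning == "frobenius":
        scale = jnp.linalg.norm(x)  # Frobenius norm of a 2D array
    elif preconditioning == "spectral":
        # Exact top singular value via SVD; a practical version would more
        # likely estimate it (e.g. with power iteration).
        scale = jnp.linalg.svd(x, compute_uv=False)[0]
    elif preconditioning == "schatten":
        # Schatten-4 norm: (sum_i s_i**4) ** 0.25 == ||x @ x.T||_F ** 0.5
        scale = jnp.linalg.norm(x @ x.T) ** 0.5
    else:
        raise ValueError(f"not covered by this sketch: {preconditioning!r}")
    return x / (scale + eps)


def newton_schulz(g, ns_coeffs=DEFAULT_NS_COEFFS, ns_steps=5,
                  preconditioning="frobenius", eps=1e-8):
    """Hypothetical helper: approximately map g = U diag(s) V^T to U V^T."""
    if isinstance(ns_coeffs[0], (tuple, list)):
        schedule = tuple(ns_coeffs)  # one (a, b, c) per step; ns_steps unused
    else:
        schedule = (tuple(ns_coeffs),) * ns_steps
    transposed = g.shape[0] > g.shape[1]
    x = g.T if transposed else g  # keep the Gram matrix x @ x.T small
    x = rescale(x, preconditioning, eps)
    for a, b, c in schedule:
        gram = x @ x.T
        # Each singular value s becomes a*s + b*s**3 + c*s**5.
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


def momentum_direction(grad, mu, beta=0.95, nesterov=True):
    """Hypothetical helper: EMA momentum; returns (new_mu, direction to orthogonalize)."""
    new_mu = beta * mu + (1.0 - beta) * grad
    # Assumed Nesterov look-ahead: blend the fresh gradient into new_mu once more.
    direction = beta * new_mu + (1.0 - beta) * grad if nesterov else new_mu
    return new_mu, direction


g = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
mu, direction = momentum_direction(g, jnp.zeros_like(g))
update = newton_schulz(direction, preconditioning="schatten")
print(jnp.linalg.svd(update, compute_uv=False))
```

Because the Frobenius norm ≥ Schatten-4 norm ≥ spectral norm, all three options bring every singular value to at most 1 before the loop, with the Frobenius option shrinking the matrix the most. Staying in that range matters: the default polynomial grows quickly for larger inputs (p(1.5) ≈ 4.48). Note also that a + b + c = 3.4445 - 4.775 + 2.0315 = 0.701, so a singular value of exactly 1 is not a fixed point; with the default coefficients the output's singular values end up in a band around 1 rather than exactly at 1.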
- Returns:
A GradientTransformation object.
References
Jordan, modded-nanogpt: Speedrunning the NanoGPT baseline, 2024
Bernstein et al., Old Optimizer, New Norm: An Anthology, 2024
Liu et al., Muon is Scalable for LLM Training, 2025, https://arxiv.org/abs/2502.16982
Boissin et al., Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning, 2025, https://arxiv.org/abs/2512.04632
Ahn et al., Dion: Distributed Orthonormalized Updates, 2025, https://arxiv.org/abs/2504.05295
Grishina et al., Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials, 2025, https://arxiv.org/abs/2506.10935
Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm, 2025, https://arxiv.org/pdf/2505.16932