optax.contrib.scale_by_muon


optax.contrib.scale_by_muon(ns_coeffs: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike] | tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike], ...] = (3.4445, -4.775, 2.0315), ns_steps: jax.typing.ArrayLike = 5, beta: jax.typing.ArrayLike = 0.95, eps: jax.typing.ArrayLike = 1e-08, mu_dtype: jax.typing.DTypeLike | None = None, *, nesterov: bool = True, adaptive: bool = False, preconditioning: Literal['frobenius', 'spectral', 'aol', 'schatten'] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None) -> base.GradientTransformation

Rescale updates according to the Muon algorithm.

Muon is a variant of Shampoo that uses the Newton-Schulz method to orthogonalize the momentum accumulated by the optimizer. Mathematically, it performs steepest descent under the Schatten-p norm for some large p. With p = ∞, it is equivalent to Shampoo without accumulation, or to steepest descent under the spectral norm.
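The orthogonalization step described above can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: the helper name `newton_schulz_orthogonalize` is hypothetical, the quintic update form is assumed from the Muon reference implementation, and only the default coefficients, step count, and Frobenius rescaling are taken from the signature above.

```python
import jax
import jax.numpy as jnp


def newton_schulz_orthogonalize(x, coeffs=(3.4445, -4.775, 2.0315), steps=5, eps=1e-8):
    """Approximately orthogonalize a 2D matrix (hypothetical helper)."""
    a, b, c = coeffs
    # Work in the wide orientation so that x @ x.T is the smaller Gram matrix.
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    # Frobenius rescaling bounds every singular value by 1 before iterating.
    x = x / (jnp.linalg.norm(x) + eps)
    for _ in range(steps):
        gram = x @ x.T
        # Quintic polynomial in x: a*x + b*(x x^T) x + c*(x x^T)^2 x.
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


key = jax.random.PRNGKey(0)
momentum = jax.random.normal(key, (4, 6))  # stand-in for an accumulated momentum
ortho = newton_schulz_orthogonalize(momentum)
singular_values = jnp.linalg.svd(ortho, compute_uv=False)
```

With these coefficients the singular values are pushed toward a band around 1 rather than exactly to 1, so the result is only approximately orthogonal.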

Parameters:
  • ns_coeffs – Coefficients for the Newton-Schulz method.

  • ns_steps – Number of Newton-Schulz iterations. Ignored if ns_coeffs is a tuple of tuples.

  • beta – Decay rate for the exponentially weighted average of grads.

  • eps – Term added to denominators to improve numerical stability.

  • mu_dtype – Data type of the momentum accumulator.

  • nesterov – Whether to use Nesterov momentum.

  • adaptive – Whether to scale the updates by the dual norm of the original updates. See <https://arxiv.org/abs/2409.20325>

  • preconditioning – Type of preconditioning to apply before the Newton-Schulz (NS) iterations. Available options:
      - 'frobenius' (default): Use Frobenius rescaling before NS.
      - 'spectral': Use spectral norm rescaling before NS.
      - 'aol': Use AOL rescaling to improve orthogonality.
      - 'schatten': Use the Schatten-4 norm for rescaling.

  • weight_dimension_numbers – An optional tree of `MuonDimensionNumbers` with the same structure as the params, specifying how to reshape each parameter before and after orthogonalization, or a callable returning such a tree. None implies that all parameters are 2D matrices.
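As an illustration of the `adaptive` option, the sketch below assumes that the dual norm in question is the nuclear norm (the dual of the spectral norm) and uses an exact SVD in place of Newton-Schulz. The variable names are hypothetical and the library's internals may differ.

```python
import jax
import jax.numpy as jnp

# Stand-in for the update (or momentum) of a single 2D parameter.
g = jax.random.normal(jax.random.PRNGKey(1), (5, 3))

# Exact orthogonal factor U V^T from the thin SVD of g.
u, s, vt = jnp.linalg.svd(g, full_matrices=False)
ortho = u @ vt

# The nuclear norm is the sum of singular values; it equals <g, ortho>.
nuclear_norm = jnp.sum(s)
inner = jnp.sum(g * ortho)

# Adaptive scaling under the stated assumption: orthogonal direction times dual norm.
adaptive_update = inner * ortho
```

Scaling by the inner product keeps the size of the update tied to the magnitude of the original gradient signal, which plain orthogonalization discards.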

Returns:

A GradientTransformation object.

References

Jordan, modded-nanogpt: Speedrunning the NanoGPT baseline, 2024

Bernstein et al., Old Optimizer, New Norm: An Anthology, 2024

Liu et al., Muon is Scalable for LLM Training, <https://arxiv.org/abs/2502.16982>, 2025

Boissin et al., Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning, <https://arxiv.org/abs/2512.04632>, 2025

Ahn et al., Dion: Distributed Orthonormalized Updates, <https://arxiv.org/abs/2504.05295>, 2025

Grishina et al., Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials, <https://arxiv.org/abs/2506.10935>, 2025

Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm, <https://arxiv.org/pdf/2505.16932>, 2025