optax.contrib.MuonState

class optax.contrib.MuonState(count: jax.typing.ArrayLike, mu: optax.Updates, ns_coeffs: jax.typing.ArrayLike)

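State for the Muon algorithm. It holds a step counter (count), the momentum accumulator with the same tree structure as the updates (mu), and the Newton-Schulz iteration coefficients used to orthogonalize updates (ns_coeffs).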