optax.ScaleByAmsgradState

class optax.ScaleByAmsgradState(count: jax.typing.ArrayLike, mu: optax.Updates, nu: optax.Updates, nu_max: optax.Updates)

State for the AMSGrad algorithm.
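The sketch below runs one AMSGrad step on a single scalar parameter to show what each field holds. It is plain Python, not Optax code. `AmsgradState`, `init`, and `update` are hypothetical names that mirror the fields of `ScaleByAmsgradState`, and the defaults (`b1=0.9`, `b2=0.999`, `eps=1e-8`) are assumed Adam-style values. The sketch makes these assumptions about the fields:

- `count` is the step counter.
- `mu` and `nu` are exponential moving averages of the gradient and the squared gradient.
- `nu_max` is the running maximum used in the denominator.

The sketch takes that maximum over the bias-corrected second moment. Formulations of AMSGrad differ on where bias correction is applied.

```python
import math
from typing import NamedTuple


class AmsgradState(NamedTuple):
    """Hypothetical scalar mirror of the fields of optax.ScaleByAmsgradState."""
    count: int     # number of update steps taken so far (assumed meaning)
    mu: float      # moving average of gradients (first moment)
    nu: float      # moving average of squared gradients (second moment)
    nu_max: float  # running maximum of the bias-corrected second moment


def init() -> AmsgradState:
    # All statistics start at zero, as in Adam.
    return AmsgradState(count=0, mu=0.0, nu=0.0, nu_max=0.0)


def update(grad: float, state: AmsgradState,
           b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8):
    """Return (scaled_update, new_state) for one AMSGrad step (sketch)."""
    mu = b1 * state.mu + (1.0 - b1) * grad
    nu = b2 * state.nu + (1.0 - b2) * grad ** 2
    count = state.count + 1
    # Correct the zero-initialized moments for their bias toward zero.
    mu_hat = mu / (1.0 - b1 ** count)
    nu_hat = nu / (1.0 - b2 ** count)
    # The AMSGrad change to Adam: the denominator statistic never decreases.
    nu_max = max(state.nu_max, nu_hat)
    scaled = mu_hat / (math.sqrt(nu_max) + eps)
    return scaled, AmsgradState(count=count, mu=mu, nu=nu, nu_max=nu_max)


state = init()
step1, state = update(1.0, state)  # large gradient: nu_max rises to about 1.0
step2, state = update(0.1, state)  # small gradient: nu_max keeps the earlier maximum
```

The second call shows how AMSGrad differs from Adam. The new gradient is small, but `nu_max` keeps the earlier, larger second-moment estimate, so the effective denominator never shrinks. In Optax, an instance of this state is what the `init` function of `optax.scale_by_amsgrad(...)` returns. User code normally passes that instance back to `update` rather than constructing it by hand.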