optax.ScaleByBeliefState

class optax.ScaleByBeliefState(count: jax.typing.ArrayLike, mu: optax.Updates, nu: optax.Updates)

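State for the AdaBelief rescaling algorithm. `count` is the number of update steps taken so far, `mu` is the running first-moment estimate of the updates, and `nu` is the running second-moment estimate of the prediction error (the update minus `mu`).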