optax.contrib.ProdigyState
class optax.contrib.ProdigyState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, params0: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)
State of the GradientTransformation returned by prodigy.