optax.contrib.ProdigyState

class optax.contrib.ProdigyState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, params0: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)

State of the GradientTransformation returned by optax.contrib.prodigy.