optax.contrib.DAdaptAdamWState
class optax.contrib.DAdaptAdamWState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)
State of the GradientTransformation returned by optax.contrib.dadapt_adamw.