optax.contrib.DAdaptAdamWState

class optax.contrib.DAdaptAdamWState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)

State of the GradientTransformation returned by optax.contrib.dadapt_adamw.