optax.contrib.DoWGState

class optax.contrib.DoWGState(init_params: optax.ArrayTree, weighted_sq_norm_grads: Array, estim_sq_dist: Array)

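State for the DoWG (Distance over Weighted Gradients) optimizer. It holds the initial parameters, a running distance-weighted sum of squared gradient norms, and a running estimate of the squared distance travelled from the initial parameters.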