optax.contrib.dowg
- optax.contrib.dowg(learning_rate: base.ScalarOrSchedule = 1.0, init_estim_sq_dist: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 0.0001, weight_decay: jax.typing.ArrayLike | None = None, mask: Any | Callable[[base.Params], Any] | None = None)
Distance over weighted Gradients optimizer.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dowg()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: ', f(params))
Objective function:  13.925367
Objective function:  13.872763
Objective function:  13.775433
Objective function:  13.596172
Objective function:  13.268837
References
Khaled et al., DoWG Unleashed: An Efficient Universal Parameter-Free Gradient Descent Method, 2023.
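The update rule behind this optimizer can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the optax implementation: `dowg_sketch` and its internal names are hypothetical, it works on flat lists of floats rather than pytrees, and adding `eps` to the denominator is an assumption based on the parameter descriptions.

```python
import math


def dowg_sketch(grad_fn, x0, steps=5, learning_rate=1.0,
                init_estim_sq_dist=None, eps=1e-4):
    """Hypothetical helper: run a few DoWG-style steps on a list of floats."""
    x = list(x0)
    # Running estimate of the squared distance to the solution.
    # Falls back to eps when no initial guess is given.
    estim_sq_dist = eps if init_estim_sq_dist is None else init_estim_sq_dist
    # Running sum of (squared distance estimate) * (squared gradient norm).
    weighted_sq_grads = 0.0
    for _ in range(steps):
        g = grad_fn(x)
        # The estimate never shrinks: keep the largest squared distance
        # travelled from the starting point so far.
        sq_dist = sum((xi - x0i) ** 2 for xi, x0i in zip(x, x0))
        estim_sq_dist = max(estim_sq_dist, sq_dist)
        weighted_sq_grads += estim_sq_dist * sum(gi * gi for gi in g)
        # Step size: distance estimate over the root of the weighted gradients.
        step = learning_rate * estim_sq_dist / (math.sqrt(weighted_sq_grads) + eps)
        x = [xi - step * gi for xi, gi in zip(x, g)]
    return x


# Same quadratic as in the Examples section: f(x) = sum(x ** 2), gradient 2 * x.
x_final = dowg_sketch(lambda x: [2.0 * v for v in x], [1.0, 2.0, 3.0])
print(sum(v * v for v in x_final))
```

On this quadratic, the first step of the sketch gives an objective of roughly 13.925, in line with the first value printed in the Examples section.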
- Parameters:
learning_rate – optional learning rate (potentially varying according to some predetermined scheduler).
init_estim_sq_dist – initial guess of the squared distance to the solution.
eps – small value to prevent division by zero in the denominator that defines the learning rate; also used as the initial guess of the squared distance to the solution if init_estim_sq_dist is None.
weight_decay – strength of the weight decay regularization.
mask – a tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation itself is applied to all parameters.
- Returns:
The corresponding optax.GradientTransformation.
Added in version 0.2.3.