optax.contrib.dog


optax.contrib.dog(learning_rate: base.ScalarOrSchedule = 1.0, init_step: tuple[Literal['distance', 'learning_rate', 'heuristic'], jax.typing.ArrayLike] = ('heuristic', 1e-06), eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike | None = None, mask: Any | Callable[[base.Params], Any] | None = None)

Distance over Gradients (DoG) optimizer.

DoG updates parameters \(x_t\) with stochastic gradients \(g_t\) according to the update rule:

\[\begin{align*} r_t &= \| x_t - x_0 \| \\ \bar{r}_t &= \max_{k \leq t} r_k \\ G_t &= \sum_{k \leq t} \|g_k\|^2 \\ \eta_t &= \frac{\bar{r}_t}{\sqrt{G_t + \epsilon}} \\ x_{t+1} & = x_{t} - \eta_t\, g_t \end{align*}\]
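The recursion above can be traced by hand with a few lines of plain Python. The following is a minimal sketch, not the optax implementation: the names dog_sketch, grad_fn, f, and grad_f are hypothetical, parameters are a flat list of floats, and \(\bar{r}_t\) is assumed to start from the heuristic \(r_0 = \alpha (1 + \|x_0\|)\) described under init_step.

```python
import math


def dog_sketch(grad_fn, x0, num_steps, alpha=1e-6, eps=1e-8):
    """Run the DoG recursion on a flat list of floats (illustrative only)."""
    x = list(x0)
    # Assumed starting point: heuristic r_0 = alpha * (1 + ||x_0||).
    r_bar = alpha * (1.0 + math.sqrt(sum(v * v for v in x0)))
    grad_sq_sum = 0.0  # G_t, the running sum of squared gradient norms
    for _ in range(num_steps):
        g = grad_fn(x)
        # r_t = ||x_t - x_0||; r_bar is its running maximum.
        r_t = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, x0)))
        r_bar = max(r_bar, r_t)
        grad_sq_sum += sum(v * v for v in g)
        # eta_t = r_bar_t / sqrt(G_t + eps)
        eta = r_bar / math.sqrt(grad_sq_sum + eps)
        # x_{t+1} = x_t - eta_t * g_t
        x = [a - eta * b for a, b in zip(x, g)]
    return x


def f(x):
    """Simple quadratic objective."""
    return sum(v * v for v in x)


def grad_f(x):
    """Gradient of the quadratic objective."""
    return [2.0 * v for v in x]


x_final = dog_sketch(grad_f, [1.0, 2.0, 3.0], num_steps=5)
print(f(x_final))
```

Because \(r_0\) is tiny, the first few steps barely move the parameters, so the objective stays just below its starting value of 14.0; the step size grows as the distance from \(x_0\) accumulates.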
Parameters:
  • learning_rate – Optional learning rate, given either as a scalar or as a schedule that varies with the step count according to a predetermined rule.

  • init_step – Initial step specification, given as a pair (tag, value), where value is a float and tag is one of the strings distance, learning_rate, or heuristic. distance sets the initial distance \(r_0\) (\(r_\epsilon\) in the paper) to the given value. learning_rate sets the initial learning rate \(\eta_0\) to the given value. heuristic sets \(r_0 = \alpha (1 + \|x_0\|)\), where \(\alpha\) is the given value. The suggested value of \(\alpha\) is 1e-6, unless the model uses batch normalization, in which case the suggested value is 1e-4. As discussed in the paper, the value should be small enough that the initial update step does not cause the model to diverge.

  • eps – Small constant \(\epsilon\) added to the sum of squared gradient norms \(G_t\) for numerical stability.

  • weight_decay – Strength of the weight decay regularization.

  • mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation itself is applied to all parameters; the mask only selects where weight decay is applied.
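Returning to init_step: a short calculation shows why a small value keeps the first update small. Set \(t = 0\) in the update rule, and assume \(\bar{r}_0 = r_0\), that is, the initial distance stands in for the zero displacement at the first step (an assumption of this sketch, not a statement about the implementation). Since \(G_0 = \|g_0\|^2\), the first displacement satisfies:

\[\begin{align*} \|x_1 - x_0\| &= \eta_0\, \|g_0\| = \frac{r_0\, \|g_0\|}{\sqrt{\|g_0\|^2 + \epsilon}} \leq r_0 \end{align*}\]

With the heuristic tag this reads \(\|x_1 - x_0\| \leq \alpha (1 + \|x_0\|)\), so \(\alpha\) bounds the first step relative to the scale of the initial parameters.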

Returns:

The corresponding optax.GradientTransformation.

Examples

>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dog()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  updates, opt_state = solver.update(
...    grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...

References

Ivgi et al., DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size Schedule, 2023.

Added in version 0.2.3.

Warning

The authors recommend using model averaging with this optimizer.

This optimizer’s init function should receive the actual parameters (not just dummy parameters) when the heuristic initial step is used.
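As one way to read the model-averaging recommendation above, the sketch below keeps a uniform running average of the parameter iterates in plain Python. It is only an illustration: update_average is a hypothetical helper, the iterates are made-up numbers, and uniform averaging is an assumed choice that may differ from the averaging scheme the authors use.

```python
def update_average(avg, x, count):
    """Fold iterate x into a uniform average of `count` earlier iterates."""
    return [a + (v - a) / (count + 1) for a, v in zip(avg, x)]


# Made-up parameter vectors standing in for successive optimizer iterates.
iterates = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

avg = iterates[0]
for count, x in enumerate(iterates[1:], start=1):
    avg = update_average(avg, x, count)

print(avg)  # [3.0, 4.0]
```

In a training loop the same update would run after each optimizer step, and the averaged parameters, rather than the last iterate, would be used for evaluation.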