optax.optimistic_adam_v2

optax.optimistic_adam_v2(learning_rate: base.ScalarOrSchedule, *, alpha: jax.typing.ArrayLike = 1.0, beta: jax.typing.ArrayLike = 1.0, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, nesterov: bool = True) → base.GradientTransformationExtraArgs

The Optimistic Adam optimizer.

This is an optimistic version of the Adam optimizer. It addresses the issue of limit cycling behavior in training Generative Adversarial Networks and other saddle-point min-max problems.
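A minimal sketch of the problem and of the optimistic fix, assuming the bilinear game \(f(x, y) = x y\) from the Examples section and using plain (non-Adam) optimistic gradient descent with a step size `eta` instead of the full optimizer. Simultaneous gradient descent-ascent spirals away from the equilibrium at the origin, while the optimistic step \(-\alpha g_t - \beta (g_t - g_{t-1})\) with \(\alpha = \beta = \) `eta` spirals in. The helpers `grad_field` and `run`, and the choice \(g_{-1} = 0\), are illustrative assumptions, not part of optax.

```python
import jax.numpy as jnp
from jax import lax


def grad_field(z):
    # Descent field for f(x, y) = x * y, where x minimizes and y maximizes:
    # (df/dx, -df/dy) = (y, -x). Both players step *against* this field.
    x, y = z[0], z[1]
    return jnp.array([y, -x])


def run(z0, eta, optimistic, steps):
    """Runs simultaneous updates on f(x, y) = x * y and returns the final (x, y)."""

    def step(carry, _):
        z, g_prev = carry
        g = grad_field(z)
        if optimistic:
            # Optimistic step: -alpha * g_t - beta * (g_t - g_{t-1}), alpha = beta = eta.
            u = -eta * g - eta * (g - g_prev)
        else:
            # Plain simultaneous gradient descent-ascent.
            u = -eta * g
        return (z + u, g), None

    # Assumption: the "previous" gradient before the first step is zero.
    (z_final, _), _ = lax.scan(step, (z0, jnp.zeros_like(z0)), None, length=steps)
    return z_final


z0 = jnp.array([1.0, 2.0])
gda = run(z0, eta=0.1, optimistic=False, steps=2000)
ogda = run(z0, eta=0.1, optimistic=True, steps=2000)
print(jnp.linalg.norm(z0), jnp.linalg.norm(gda), jnp.linalg.norm(ogda))
```

The extra term \(-\beta (g_t - g_{t-1})\) extrapolates from the most recent change in the gradient, which damps the rotation that makes plain descent-ascent circle or drift outward on this problem.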

The “_v2” suffix refers to the reworked version of the interface (not the algorithm) and will eventually replace the interface of the current optimistic_adam() function.

The algorithm is as follows. First, we define the following parameters:

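  • learning_rate: the learning rate, a global scaling factor.

  • \(\alpha\): the alpha rate in optimistic gradient descent, weighting the current step \(r_t\).

  • \(\beta\): the beta rate in optimistic gradient descent, weighting the change \(r_t - r_{t-1}\).

  • \(\beta_1\): the exponential decay rate for the first moment estimate.

  • \(\beta_2\): the exponential decay rate for the second moment estimate.

  • \(\varepsilon\): a small constant added to the denominator for numerical stability (eps).

  • \(\bar{\varepsilon}\): a small constant added to the second moment estimate inside the square root (eps_root).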

Second, we define the following variables:

  • \(g_t\): the incoming gradient.

  • \(m_t\): the biased first moment estimate.

  • \(v_t\): the biased second raw moment estimate.

  • \(\hat{m}_t\): the bias-corrected first moment estimate.

  • \(\hat{v}_t\): the bias-corrected second raw moment estimate.

  • \(r_t\): the signal-to-noise ratio (SNR) vector.

  • \(u_t\): the outgoing update vector.

  • \(S_t\): the state of the optimizer.

Finally, on each iteration, the variables are updated as follows:

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t - 1} + (1 - \beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t - 1} + (1 - \beta_2) \cdot g_t^2 \\ \hat{m}_t &\leftarrow m_t / {(1 - \beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1 - \beta_2^t)} \\ r_t &\leftarrow \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) \\ u_t &\leftarrow -\alpha r_t - \beta (r_t - r_{t - 1}) \\ S_t &\leftarrow (m_t, v_t, r_t). \end{align*}\]
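The update rule above can be transcribed almost line for line. The following is a minimal single-array sketch that follows the displayed equations literally, with `alpha` and `beta` as the two optimism weights. It assumes \(r_0 = 0\), does not apply `learning_rate`, and ignores the `nesterov` and `mu_dtype` options, so it should not be expected to reproduce `optax.optimistic_adam_v2` numerically. The names `OptimisticAdamState`, `init` and `update` are hypothetical.

```python
from typing import NamedTuple

import jax.numpy as jnp


class OptimisticAdamState(NamedTuple):
    """Hypothetical container for S_t = (m_t, v_t, r_t), plus the step count t."""
    count: jnp.ndarray  # t, the number of updates performed so far
    m: jnp.ndarray      # biased first moment estimate m_t
    v: jnp.ndarray      # biased second raw moment estimate v_t
    r: jnp.ndarray      # previous SNR vector r_{t-1} (assumed zero initially)


def init(params):
    zeros = jnp.zeros_like(params)
    return OptimisticAdamState(jnp.zeros([], jnp.int32), zeros, zeros, zeros)


def update(g, state, alpha=1.0, beta=1.0, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    t = state.count + 1
    m = b1 * state.m + (1 - b1) * g                  # m_t
    v = b2 * state.v + (1 - b2) * g**2               # v_t
    m_hat = m / (1 - b1**t)                          # bias-corrected first moment
    v_hat = v / (1 - b2**t)                          # bias-corrected second moment
    r = m_hat / (jnp.sqrt(v_hat + eps_root) + eps)   # r_t
    u = -alpha * r - beta * (r - state.r)            # u_t
    return u, OptimisticAdamState(t, m, v, r)


g = jnp.array([2.0, -3.0])
state = init(jnp.zeros(2))
u1, state = update(g, state)  # r_1 is close to sign(g) and r_0 = 0, so u_1 is near -(alpha + beta) * sign(g)
u2, state = update(g, state)  # same gradient again: r_2 is close to r_1, so u_2 is near -alpha * sign(g)
print(u1, u2)
```

With a constant gradient, the optimistic correction \(r_t - r_{t-1}\) vanishes after the first step and only the \(-\alpha r_t\) term remains; the correction only acts when the normalized step direction changes between iterations.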
Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • alpha – One of two scalar optimism parameters in optimistic gradient descent (\(\alpha\) in the update rule).

  • beta – One of two scalar optimism parameters in optimistic gradient descent (\(\beta\) in the update rule).

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – Term added to the denominator to improve numerical stability (\(\varepsilon\)).

  • eps_root – Term added to the second moment estimate inside the square root to improve numerical stability (\(\bar{\varepsilon}\)). If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

  • mu_dtype – Optional dtype to be used for the first-moment accumulator; if None, the dtype is inferred from params and updates.

  • nesterov – Whether to use Nesterov momentum.
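The eps_root requirement can be checked directly. A minimal sketch, assuming a single first step (\(t = 1\)) of the update equations with zero-initialized moments, where the bias corrections give \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\): differentiating \(r_1\) with respect to the incoming gradient at \(g_1 = 0\) yields NaN when `eps_root = 0`, because the derivative of \(\sqrt{x}\) is infinite at \(x = 0\), and a finite value once `eps_root` is positive. The function `snr_first_step` is hypothetical.

```python
import jax
import jax.numpy as jnp


def snr_first_step(g, eps=1e-8, eps_root=0.0):
    # At t = 1 with zero-initialized moments, the bias corrections give
    # m_hat = g and v_hat = g**2, so r_1 = g / (sqrt(g**2 + eps_root) + eps).
    m_hat = g
    v_hat = g**2
    return m_hat / (jnp.sqrt(v_hat + eps_root) + eps)


# Differentiate r_1 with respect to the incoming gradient at g = 0.
grad_no_root = jax.grad(snr_first_step)(0.0)
grad_with_root = jax.grad(lambda g: snr_first_step(g, eps_root=1e-8))(0.0)
print(grad_no_root, grad_with_root)  # NaN vs. a finite value
```

In a meta-learning setup that differentiates through the update, a single zero entry in the incoming gradient can therefore be enough to turn the outer gradient into NaN when eps_root is left at its default of 0.0.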

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> from jax import numpy as jnp, lax
>>> def f(x, y):
...  return x * y  # simple bilinear function
>>> opt = optax.optimistic_adam_v2(1.0, alpha=1e-2, beta=1.0)
>>> def step(state, _):
...  params, opt_state = state
...  distance = jnp.hypot(*params)
...  grads = jax.grad(f, argnums=(0, 1))(*params)
...  grads = grads[0], -grads[1]  # x descends on f, y ascends (negated gradient)
...  updates, opt_state = opt.update(grads, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  return (params, opt_state), distance
>>> params = 1.0, 2.0
>>> opt_state = opt.init(params)
>>> _, distances = lax.scan(step, (params, opt_state), length=1025)
>>> for i in range(6):
...  print(f"{distances[4**i]:.3f}")
2.243
2.195
2.161
2.055
0.796
0.001

References

Daskalakis et al., Training GANs with Optimism, 2017