optax.optimistic_adam_v2
- optax.optimistic_adam_v2(learning_rate: base.ScalarOrSchedule, *, alpha: jax.typing.ArrayLike = 1.0, beta: jax.typing.ArrayLike = 1.0, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, nesterov: bool = True) → base.GradientTransformationExtraArgs
The Optimistic Adam optimizer.
This is an optimistic version of the Adam optimizer. It addresses the issue of limit cycling behavior in training Generative Adversarial Networks and other saddle-point min-max problems.
The "_v2" suffix refers to the reworked version of the interface (not the algorithm) and will eventually replace the interface of the current optimistic_adam() function.
The algorithm is as follows. First, we define the following parameters:
\(\text{learning\_rate}\): the learning rate.
\(\alpha\): the alpha rate in optimistic gradient descent.
\(\beta\): the beta rate in optimistic gradient descent.
\(\beta_1\): the exponential decay rate for the first moment estimate.
\(\beta_2\): the exponential decay rate for the second moment estimate.
Second, we define the following variables:
\(g_t\): the incoming gradient.
\(m_t\): the biased first moment estimate.
\(v_t\): the biased second raw moment estimate.
\(\hat{m}_t\): the bias-corrected first moment estimate.
\(\hat{v}_t\): the bias-corrected second raw moment estimate.
\(r_t\): the signal-to-noise ratio (SNR) vector.
\(u_t\): the outgoing update vector.
\(S_t\): the state of the optimizer.
Finally, on each iteration, the variables are updated as follows:
\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t - 1} + (1 - \beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t - 1} + (1 - \beta_2) \cdot g_t^2 \\ \hat{m}_t &\leftarrow m_t / {(1 - \beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1 - \beta_2^t)} \\ r_t &\leftarrow \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) \\ u_t &\leftarrow -\text{learning\_rate} \cdot \left(\alpha \cdot r_t + \beta \cdot (r_t - r_{t - 1})\right) \\ S_t &\leftarrow (m_t, v_t, r_t). \end{align*}\]
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
alpha – One of two scalar optimism parameters in optimistic gradient descent.
beta – One of two scalar optimism parameters in optimistic gradient descent.
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
nesterov – Whether to use Nesterov momentum.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
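The requirement that eps_root be non-zero when differentiating through the transformation can be illustrated without any library. The following is a minimal pure-Python sketch, not Optax code: the helper name dr_dv_hat is hypothetical, and it only evaluates the analytic derivative of the ratio \(r = \hat{m} / (\sqrt{\hat{v} + \bar{\varepsilon}} + \varepsilon)\) with respect to \(\hat{v}\) for a single scalar coordinate. It assumes that an automatic-differentiation system would hit the same unbounded square-root derivative at zero.

```python
import math


def dr_dv_hat(m_hat, v_hat, eps=1e-8, eps_root=0.0):
    """Analytic d r / d v_hat for r = m_hat / (sqrt(v_hat + eps_root) + eps).

    Hypothetical helper for illustration only; not part of Optax.
    """
    root = math.sqrt(v_hat + eps_root)
    # The derivative of sqrt(x), 1 / (2 * sqrt(x)), is unbounded at x = 0.
    d_root = math.inf if root == 0.0 else 1.0 / (2.0 * root)
    # Chain rule: d r / d v_hat = -m_hat / (root + eps)**2 * d sqrt / d v_hat.
    return -m_hat / (root + eps) ** 2 * d_root


# A coordinate whose second moment estimate is exactly zero.
print(dr_dv_hat(1.0, 0.0))                 # non-finite when eps_root is zero
print(dr_dv_hat(1.0, 0.0, eps_root=1e-8))  # finite once eps_root > 0
```

With eps_root left at zero, the derivative is infinite (or NaN when the first moment is also zero), which would contaminate any meta-gradient computed through the update; a small positive eps_root keeps it finite.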
Examples
>>> import optax
>>> import jax
>>> from jax import numpy as jnp, lax
>>> def f(x, y):
...   return x * y  # simple bilinear function
>>> opt = optax.optimistic_adam_v2(1.0, alpha=1e-2, beta=1.0)
>>> def step(state, _):
...   params, opt_state = state
...   distance = jnp.hypot(*params)
...   grads = jax.grad(f, argnums=(0, 1))(*params)
...   grads = grads[0], -grads[1]
...   updates, opt_state = opt.update(grads, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   return (params, opt_state), distance
>>> params = 1.0, 2.0
>>> opt_state = opt.init(params)
>>> _, distances = lax.scan(step, (params, opt_state), length=1025)
>>> for i in range(6):
...   print(f"{distances[4**i]:.3f}")
2.243
2.195
2.161
2.055
0.796
0.001
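To trace the update rule by hand, the per-iteration equations can also be written out for a single scalar coordinate. The following is a minimal pure-Python sketch under stated assumptions, not the library implementation: the function name optimistic_adam_step and the state layout (m, v, r_prev, t) are hypothetical, Nesterov momentum is omitted, the starting value \(r_0 = 0\) is an assumption that the text above does not specify, and learning_rate is assumed to scale both the alpha and the beta terms.

```python
import math


def optimistic_adam_step(g, state, *, learning_rate, alpha=1.0, beta=1.0,
                         b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One scalar step of the update equations (hypothetical helper).

    `state` is (m, v, r_prev, t), where t counts the steps already taken.
    Returns the outgoing update u and the new state.
    """
    m, v, r_prev, t = state
    t += 1
    m = b1 * m + (1.0 - b1) * g          # biased first moment estimate
    v = b2 * v + (1.0 - b2) * g * g      # biased second raw moment estimate
    m_hat = m / (1.0 - b1 ** t)          # bias-corrected first moment
    v_hat = v / (1.0 - b2 ** t)          # bias-corrected second raw moment
    r = m_hat / (math.sqrt(v_hat + eps_root) + eps)  # signal-to-noise ratio
    # Optimistic step: alpha weights r_t, beta weights the change in r.
    # Assumption: learning_rate multiplies the whole expression.
    u = -learning_rate * (alpha * r + beta * (r - r_prev))
    return u, (m, v, r, t)


# Assumed starting state: zero moments, r_0 = 0, no steps taken yet.
init_state = (0.0, 0.0, 0.0, 0)
u, new_state = optimistic_adam_step(
    2.0, init_state, learning_rate=1.0, alpha=1e-2, beta=1.0)
print(u, new_state)
```

On the first step the bias correction makes the ratio r close to the sign of the gradient, so with these settings the update is approximately -(alpha + beta) times learning_rate. Because this sketch leaves out Nesterov momentum and fixes its own starting state, its trajectory is not expected to reproduce the distances printed in the example above.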
References
Daskalakis et al., Training GANs with Optimism, 2017