optax.contrib.acprop#
- optax.contrib.acprop(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#
The ACProp optimizer.
Follows the implementation from the original repo in PyTorch: juntang-zhuang/ACProp-Optimizer.
ACProp is an adaptive optimizer that combines centering of second momentum and asynchronous update. For the update at step t, the denominator uses information up to step t-1, while the numerator uses the gradient at step t.
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments
b1,b2,epsandeps_rootrespectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.The
initfunction of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), theupdatefunction of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\ \hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot g_t / \left(\sqrt{\hat{s}_{t-1}} + \varepsilon \right) \\ S_t &\leftarrow (m_t, s_t). \end{align*}\]- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see
optax.scale_by_learning_rate().b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
- Returns:
The corresponding GradientTransformation.
References
Zuhang et al, Momentum Centering and Asynchronous Update for Adaptive Gradient Methods, 2021