optax.signum
- optax.signum(learning_rate: base.ScalarOrSchedule, beta: jax.typing.ArrayLike = 0.9, accumulator_dtype: Any | None = None) → base.GradientTransformationExtraArgs
A variant of SGD using signs of the components of an EMA of the gradient.
The update \(u_t\) is defined from the gradients \(g_t\) as:
\[m_t \leftarrow \beta\, m_{t-1} + (1 - \beta)\, g_t \\ u_t \leftarrow -\alpha_t\, \text{sign}\,(m_t), \]
where \(\alpha_t\) is the learning rate at iteration \(t\) and \(m_t\) is an exponential moving average (EMA) of the gradient.
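As a rough illustration of the update rule above, the following sketch writes one step by hand in plain JAX. It is not the optax implementation: the helper name `signum_step` is hypothetical, and the zero initialisation of the accumulator is an assumption made for the example.

```python
import jax
import jax.numpy as jnp


def signum_step(m_prev, grads, learning_rate, beta=0.9):
    """One hand-written Signum step on pytrees (illustrative only)."""
    # EMA of the gradient: m_t = beta * m_{t-1} + (1 - beta) * g_t
    m = jax.tree_util.tree_map(
        lambda m_i, g_i: beta * m_i + (1.0 - beta) * g_i, m_prev, grads
    )
    # Signed update: u_t = -alpha_t * sign(m_t)
    updates = jax.tree_util.tree_map(
        lambda m_i: -learning_rate * jnp.sign(m_i), m
    )
    return updates, m


grads = {"w": jnp.array([0.5, -2.0, 0.0])}
# Assumption for this sketch: the accumulator starts at zero.
m0 = jax.tree_util.tree_map(jnp.zeros_like, grads)
updates, m1 = signum_step(m0, grads, learning_rate=0.1)
```

Only the sign of each accumulator entry reaches the update, so every nonzero entry moves its parameter by exactly the learning rate in magnitude, while an entry that is exactly zero produces no movement.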
- Parameters:
learning_rate – A global scaling factor, either fixed or varying over iterations according to a schedule; see optax.scale_by_learning_rate().
beta – Decay rate of the exponential moving average.
accumulator_dtype – Data type for the EMA accumulator.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
References
Bernstein et al., signSGD: Compressed Optimisation for Non-Convex Problems, 2018
Zhao et al., Deconstructing What Makes a Good Optimizer for Language Models, 2024, https://arxiv.org/abs/2407.07972