optax.signum

optax.signum(learning_rate: base.ScalarOrSchedule, beta: jax.typing.ArrayLike = 0.9, accumulator_dtype: Any | None = None) -> base.GradientTransformationExtraArgs

A variant of SGD that uses the sign of each component of an exponential moving average (EMA) of the gradient.

The update \(u_t\) is defined from the gradients \(g_t\) as:

\[m_t \leftarrow \beta\, m_{t-1} + (1 - \beta)\, g_t \\ u_t \leftarrow -\alpha_t\, \text{sign}\,(m_t), \]

where \(\alpha_t\) is the learning rate at iteration \(t\) and \(m_t\) is the EMA of the gradient.
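The rule above can be traced by hand in plain Python. This is a minimal sketch for illustration, not the optax implementation: the helper names `sign` and `signum_step` are hypothetical, the accumulator is assumed to start at zero, and the \(m\) on the right-hand side of the first line is read as the accumulator value from the previous step.

```python
def sign(x):
    """Return -1.0, 0.0, or 1.0 according to the sign of x."""
    return float((x > 0) - (x < 0))


def signum_step(m_prev, grads, lr, beta=0.9):
    """One hand-rolled step of the rule above on flat lists of floats."""
    # EMA of the gradient: beta * (previous accumulator) + (1 - beta) * g_t.
    m = [beta * mp + (1.0 - beta) * g for mp, g in zip(m_prev, grads)]
    # Update: -alpha_t * sign(m_t). Only the sign of the EMA is used,
    # so every nonzero component moves by exactly lr in magnitude.
    u = [-lr * sign(mi) for mi in m]
    return m, u


# Assumption: the accumulator starts at zero.
m = [0.0, 0.0]

# Step 1: both components follow the sign of the first gradient.
m, u = signum_step(m, [4.0, -0.5], lr=0.1)

# Step 2: the first gradient component flips sign, but the EMA of that
# component is still positive, so the update direction does not change yet.
m, u2 = signum_step(m, [-1.0, -0.5], lr=0.1)
```

The second step shows the smoothing effect of \(\beta\): a single gradient of opposite sign does not reverse the update while the accumulated average keeps its sign.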

Parameters:
  • learning_rate – A global scaling factor, either a fixed value or a schedule that varies over iterations; see optax.scale_by_learning_rate().

  • beta – Exponential moving average decay rate.

  • accumulator_dtype – Data type for the EMA accumulator.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

References

Bernstein et al., signSGD: Compressed optimization for Non-Convex Problems, 2018

Zhao et al., Deconstructing What Makes a Good Optimizer for Language Models (https://arxiv.org/abs/2407.07972), 2024