optax.sgd

optax.sgd(learning_rate: base.ScalarOrSchedule, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False, accumulator_dtype: Any | None = None) -> base.GradientTransformationExtraArgs

A canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent. It also supports momentum and Nesterov acceleration, since these are standard practice when using stochastic gradient descent to train deep neural networks.

The canonical stochastic gradient descent returns an update \(u_t\) of the form

\[u_t \leftarrow -\alpha_t g_t, \]

where \(g_t\) is the gradient of the objective (potentially preprocessed by other transformations) and \(\alpha_t\) is the learning_rate at time \(t\) (constant or selected by an optax.Schedule).
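A minimal plain-Python sketch of this update rule follows. It assumes parameters and gradients are flat lists of floats (optax itself operates on JAX pytrees), and the helpers `sgd_update` and `apply_updates` are hypothetical names for illustration, not optax functions. `learning_rate` may be a number or a callable of the step count, mirroring the constant-or-schedule choice described above.

```python
# Plain-Python sketch of u_t = -alpha_t * g_t (not optax's implementation).
# Assumption: params and gradients are flat lists of floats.

def sgd_update(grads, t, learning_rate):
    """Return u_t = -alpha_t * g_t; learning_rate is a float or a schedule t -> alpha_t."""
    alpha_t = learning_rate(t) if callable(learning_rate) else learning_rate
    return [-alpha_t * g for g in grads]

def apply_updates(params, updates):
    """Add the updates to the parameters elementwise."""
    return [p + u for p, u in zip(params, updates)]

# One step on f(x) = sum(x ** 2), whose gradient is 2 * x.
params = [1.0, 2.0, 3.0]
grads = [2.0 * p for p in params]
params = apply_updates(params, sgd_update(grads, t=0, learning_rate=0.003))
print(params)  # approximately [0.994, 1.988, 2.982]: each entry scaled by 1 - 2 * 0.003
```

Because the update is linear in the gradient, a constant learning rate simply shrinks each coordinate of this quadratic by the same factor per step.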

Stochastic gradient descent with momentum takes two possible forms.

\[\begin{align*} m_t &\leftarrow g_t + \mu m_{t-1} \\ u_t &\leftarrow \begin{cases} -\alpha_t m_t & \text{ if } \texttt{nesterov = False} \\ -\alpha_t (g_t + \mu m_t) & \text{ if } \texttt{nesterov = True} \end{cases} \\ S_t &\leftarrow m_t, \end{align*}\]

where \(\mu\) is the momentum parameter and \(S_t\) is the state of the optimizer.
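The recurrence can be traced on a scalar with a small plain-Python sketch. It assumes the momentum buffer starts at \(m_0 = 0\) and uses the hypothetical helpers `momentum_update` and `run`; it illustrates the equations above, not optax's internal code.

```python
# Plain-Python sketch of SGD with momentum on a scalar parameter.
# Assumption: the momentum buffer starts at m_0 = 0.

def momentum_update(g, m_prev, alpha, mu, nesterov=False):
    """Return (u_t, m_t) for gradient g_t and previous state m_{t-1}."""
    m = g + mu * m_prev                 # m_t = g_t + mu * m_{t-1}
    if nesterov:
        u = -alpha * (g + mu * m)       # u_t = -alpha_t * (g_t + mu * m_t)
    else:
        u = -alpha * m                  # u_t = -alpha_t * m_t
    return u, m                         # new state S_t = m_t

def run(nesterov, steps, x=1.0, alpha=0.1, mu=0.9):
    """Minimize f(x) = x ** 2 (gradient 2 * x) for a few steps."""
    m = 0.0
    for _ in range(steps):
        u, m = momentum_update(2.0 * x, m, alpha, mu, nesterov)
        x = x + u
    return x

print(run(nesterov=False, steps=2), run(nesterov=True, steps=2))
```

With `mu = 0.0` both branches reduce to the canonical update. On the first step from rest, the Nesterov form moves further (to 0.62 rather than 0.8 here) because its update adds the extra \(\mu m_t\) term.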

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate().

  • momentum – Decay rate used by the momentum term. If set to None, momentum is not used at all.

  • nesterov – Whether Nesterov momentum is used.

  • accumulator_dtype – Optional dtype to be used for the accumulator; if None, the dtype is inferred from params and updates.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

References

Sutskever et al., On the importance of initialization and momentum in deep learning, 2013