optax.sgd
optax.sgd(learning_rate: base.ScalarOrSchedule, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False, accumulator_dtype: Any | None = None) -> base.GradientTransformationExtraArgs
A canonical Stochastic Gradient Descent optimizer.
This implements stochastic gradient descent. It also supports momentum and Nesterov acceleration, as these are standard practice when training deep neural networks with stochastic gradient descent.
The canonical stochastic gradient descent returns an update \(u_t\) of the form
\[u_t \leftarrow -\alpha_t g_t,\]
where \(g_t\) is the gradient of the objective (potentially preprocessed by other transformations) and \(\alpha_t\) is the learning_rate at time \(t\) (constant or selected by an optax.Schedule).
Stochastic gradient descent with momentum takes one of two forms, selected by nesterov:
\[\begin{align*} m_t &\leftarrow g_t + \mu m_{t-1} \\ u_t &\leftarrow \begin{cases} -\alpha_t m_t & \text{ if } \texttt{nesterov = False} \\ -\alpha_t (g_t + \mu m_t) & \text{ if } \texttt{nesterov = True} \end{cases} \\ S_t &\leftarrow m_t, \end{align*}\]
where \(\mu\) is the momentum parameter, \(S_t\) is the state of the optimizer, and the accumulator starts at \(m_0 = 0\).
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving across iterations according to a schedule; see optax.scale_by_learning_rate().
momentum – Decay rate used by the momentum term. When it is set to None, momentum is not used at all.
nesterov – Whether Nesterov momentum is used. Has no effect when momentum is None.
accumulator_dtype – Optional dtype to be used for the accumulator; if None, the dtype is inferred from params and updates.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
References
Sutskever et al., On the importance of initialization and momentum in deep learning, 2013