optax.sm3
- optax.sm3(learning_rate: jax.typing.ArrayLike, momentum: jax.typing.ArrayLike = 0.9) -> base.GradientTransformationExtraArgs
The SM3 optimizer.
SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a memory-efficient adaptive optimizer designed to decrease memory overhead when training very large models, such as the Transformer for machine translation, BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) applies to tensors of arbitrary dimensions and any predefined cover of the parameters; 2) adapts the learning rates in a data-driven manner (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence guarantees in stochastic convex optimization settings.
The init function of this optimizer initializes an internal state \(S_0 := \{\mu_0, w_1\} = \{0, 0\}\), representing initial estimates for the cumulative squared gradients and the weights. These values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have:
SM3-I Algorithm
\[\begin{array}{l}
\text{parameters: learning rate } \eta \\
\text{initialize } w_1 = 0; \forall r \in [k]: \mu_0(r) = 0 \\
\text{for } t = 1, \ldots, T \text{ do} \\
\quad \text{receive gradient } g_t = \nabla \ell_t(w_t) \\
\quad \text{for } r = 1, \ldots, k \text{ do} \\
\quad \quad \mu_t(r) \leftarrow \mu_{t-1}(r) + \max_{j \in S_r} g_t^2(j) \\
\quad \text{for } i = 1, \ldots, d \text{ do} \\
\quad \quad \nu_t(i) \leftarrow \min_{r:S_r \ni i} \mu_t(r) \\
\quad \quad w_{t+1}(i) \leftarrow w_t(i) - \eta \frac{g_t(i)}{\sqrt{\nu_t(i)}} \\
\quad \quad \text{with the convention that } 0/0 = 0
\end{array}\]
SM3-II Algorithm
SM3-II takes a learning rate \(\eta\) as its parameter and starts from initial weights \(w_1\). At each step it receives the gradient \(g_t\) and, for every coordinate \(i\), computes \(\nu'_t(i)\) as the minimum of the previous accumulators \(\mu'_{t-1}(r)\) over all cover sets \(S_r\) that contain \(i\), plus \(g_t^2(i)\). Each accumulator \(\mu'_t(r)\) is then set to the maximum of \(\nu'_t(i)\) over the coordinates \(i \in S_r\). SM3-II starts with an initial state \(S_0 := (m_0, s_0)\) set to zero, storing estimates for the first and second moments as pytrees with the same shape as the model updates.
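As an illustration of the description above, here is a minimal JAX sketch of one SM3-II step for a single 2-D parameter. It is not the optax implementation: the function name `sm3_ii_step`, the accumulator names `mu_row` and `mu_col`, and the choice of cover (each row and each column of the matrix is one set \(S_r\)) are assumptions made for this example, and momentum is left out.

```python
import jax.numpy as jnp


def sm3_ii_step(w, g, mu_row, mu_col, lr):
    """One SM3-II step for a 2-D parameter (illustrative sketch only).

    Assumed cover: every row and every column of `w` is one set S_r, so
    entry (i, j) belongs to exactly two sets: row i and column j.
    """
    # nu'_t(i): minimum of the previous accumulators of the sets covering
    # the entry, plus the squared gradient of that entry.
    nu = jnp.minimum(mu_row[:, None], mu_col[None, :]) + g**2
    # Preconditioned step with the convention 0/0 = 0.
    safe_nu = jnp.where(nu > 0, nu, 1.0)
    step = jnp.where(nu > 0, g / jnp.sqrt(safe_nu), 0.0)
    # mu'_t(r): maximum of nu'_t(i) over the entries i in S_r.
    new_mu_row = jnp.max(nu, axis=1)
    new_mu_col = jnp.max(nu, axis=0)
    return w - lr * step, new_mu_row, new_mu_col


w = jnp.array([[1.0, 2.0], [3.0, 4.0]])
mu_row = jnp.zeros(w.shape[0])  # one accumulator per row
mu_col = jnp.zeros(w.shape[1])  # one accumulator per column
g = 2.0 * w  # gradient of sum(w ** 2)
w_next, mu_row, mu_col = sm3_ii_step(w, g, mu_row, mu_col, lr=0.1)
```

With this cover, an m-by-n matrix needs only m + n accumulator entries instead of m * n, which is the source of the memory saving. On the first step all accumulators are zero, so every nonzero gradient entry is scaled to magnitude one before the learning rate is applied.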
\[\begin{array}{l}
\text{parameters: learning rate } \eta \\
\text{initialize } w_1 = 0; \forall r \in [k]: \mu'_0(r) = 0 \\
\text{for } t = 1, \ldots, T \text{ do} \\
\quad \text{receive gradient } g_t = \nabla \ell_t(w_t) \\
\quad \text{initialize } \mu'_t(r) = 0 \text{ for all } r \in [k] \\
\quad \text{for } i = 1, \ldots, d \text{ do} \\
\quad \quad \nu'_t(i) \leftarrow \min_{r:S_r \ni i} \mu'_{t-1}(r) + g_t^2(i) \\
\quad \quad w_{t+1}(i) \leftarrow w_t(i) - \eta \frac{g_t(i)}{\sqrt{\nu'_t(i)}} \\
\quad \quad \text{with the convention that } 0/0 = 0 \\
\quad \quad \text{for all } r : S_r \ni i \text{ do} \\
\quad \quad \quad \mu'_t(r) \leftarrow \max\{\mu'_t(r), \nu'_t(i)\}
\end{array}\]
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
momentum – Decay rate used by the momentum term (when it is set to None, momentum is not used at all).
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sm3(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
References
Anil et al., Memory-Efficient Adaptive Optimization, 2019