optax.noisy_sgd
optax.noisy_sgd(learning_rate: base.ScalarOrSchedule, eta: jax.typing.ArrayLike = 0.01, gamma: jax.typing.ArrayLike = 0.55, key: jax.typing.ArrayLike | None = None, *, seed: int | None = None) → base.GradientTransformationExtraArgs
A variant of SGD with added noise.
Noisy SGD is a variant of optax.sgd() that incorporates Gaussian noise into the updates. Neelakantan et al. (2015) found that adding noise to the gradients can improve both the training error and the generalization error of very deep networks. The update \(u_t\) is modified to include this noise as follows:
\[u_t \leftarrow -\alpha_t \left(g_t + N(0, \sigma_t^2)\right),\]
where \(\alpha_t\) is the learning rate at step \(t\), \(g_t\) is the gradient, and \(N(0, \sigma_t^2)\) represents Gaussian noise with zero mean and variance \(\sigma_t^2\).
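To make the update rule concrete, here is a minimal sketch in JAX of a single noisy step for one array, assuming \(\sigma_t^2\) is already known (its schedule is given next). The helper name noisy_update is hypothetical and not part of optax, and it handles a single array rather than a parameter pytree.

```python
import jax
import jax.numpy as jnp


def noisy_update(grad, lr, sigma2, key):
    """Hypothetical helper: u = -lr * (grad + n), with n ~ N(0, sigma2) elementwise."""
    # Standard normal draw with the same shape and dtype as the gradient.
    noise = jax.random.normal(key, grad.shape, grad.dtype)
    # Multiplying by sqrt(sigma2) turns unit variance into variance sigma2.
    return -lr * (grad + jnp.sqrt(sigma2) * noise)


key = jax.random.PRNGKey(0)
g = jnp.array([2.0, 4.0, 6.0])  # gradient of sum(x**2) at x = [1, 2, 3]
u = noisy_update(g, lr=0.003, sigma2=0.01, key=key)
u_plain = noisy_update(g, lr=0.003, sigma2=0.0, key=key)  # sigma2 = 0 gives plain SGD
```

Setting sigma2 to 0 removes the noise term and recovers the plain SGD step \(-\alpha_t g_t\).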
The variance of this noise decays over time according to the formula
\[\sigma_t^2 = \frac{\eta}{(1+t)^{\gamma}},\]
where \(t\) is the update step (starting at \(t = 0\), so that \(\sigma_0^2 = \eta\)), \(\gamma\) is the decay rate parameter gamma, and \(\eta\) is the initial variance eta.
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate().
eta – Initial variance of the Gaussian noise added to the gradients.
gamma – A parameter controlling the annealing of the noise over the step count t: the variance decays according to (1+t)**(-gamma).
key – Random generator key used for noise generation.
seed – Deprecated; use key instead.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
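The decay schedule can be tabulated directly. The following plain-Python sketch evaluates the formula for \(\sigma_t^2\) with the default eta=0.01 and gamma=0.55 from the signature; noise_variance is a hypothetical helper, and counting t from 0 at the first update is an assumption read off the (1+t) term.

```python
def noise_variance(t, eta=0.01, gamma=0.55):
    """Hypothetical helper: sigma_t^2 = eta / (1 + t)**gamma, with t = 0 at the first update."""
    return eta / (1 + t) ** gamma


# Noise variance over the first five update steps with the default eta and gamma.
variances = [noise_variance(t) for t in range(5)]
```

With these defaults the variance first falls below eta / 10 at t = 65, because (1+t)**0.55 exceeds 10 only once 1 + t > 10**(1/0.55) ≈ 65.8.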
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003, key=0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
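For comparison, the following hypothetical sketch runs the same kind of loop in plain JAX without optax. It combines the update rule and the variance schedule described above, splits the random key at every step so each update draws fresh noise, and uses jax.lax.scan in place of a Python loop. It is not optax's implementation and consumes random keys differently, so it should not be expected to reproduce the printed objective values.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Same quadratic objective as in the example above.
    return jnp.sum(x ** 2)


def run_noisy_sgd(params, key, num_steps, lr=0.003, eta=0.01, gamma=0.55):
    """Hypothetical from-scratch noisy SGD loop; returns the final parameters."""

    def step(carry, t):
        params, key = carry
        # Split off a fresh subkey so every step draws independent noise.
        key, subkey = jax.random.split(key)
        grad = jax.grad(f)(params)
        # Noise variance sigma_t^2 = eta / (1 + t)**gamma, with t = 0 at the first step.
        sigma2 = eta / (1.0 + t) ** gamma
        noise = jax.random.normal(subkey, params.shape, params.dtype)
        # u_t = -lr * (g_t + N(0, sigma_t^2)), then params <- params + u_t.
        params = params - lr * (grad + jnp.sqrt(sigma2) * noise)
        return (params, key), None

    (params, _), _ = jax.lax.scan(step, (params, key), jnp.arange(num_steps))
    return params


params0 = jnp.array([1.0, 2.0, 3.0])
key = jax.random.PRNGKey(0)
final = run_noisy_sgd(params0, key, num_steps=500)
```

Because all randomness is derived from key, calling run_noisy_sgd twice with the same key returns identical parameters.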
References
Neelakantan et al., Adding Gradient Noise Improves Learning for Very Deep Networks, 2015