optax.add_noise

optax.add_noise(eta: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, key: Array | int | None = None, *, seed: int | None = None) -> optax.GradientTransformation

Add Gaussian noise, with a variance that is annealed over steps, to the gradient updates.

Parameters:
  • eta – Base variance of the Gaussian noise added to the gradient.

  • gamma – Decay exponent used to anneal the variance over steps.

  • key – Random generator key for noise generation (an integer seed is also accepted).

  • seed – Deprecated; use key instead.
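The two parameters combine into a per-step noise scale. The following is a minimal sketch, assuming the annealing schedule from Neelakantan et al. (variance = eta / (1 + t)**gamma at 0-indexed step t); the helper noise_std is hypothetical and not part of optax, so check the optax source for its exact step indexing.

```python
import math


def noise_std(step: int, eta: float, gamma: float) -> float:
    """Standard deviation of the Gaussian noise at a 0-indexed step.

    Hypothetical helper; assumes variance = eta / (1 + step) ** gamma,
    the schedule described by Neelakantan et al. (2015).
    """
    variance = eta / (1 + step) ** gamma
    return math.sqrt(variance)


# Noise scale for the first few updates, using eta and gamma from the example.
for step in range(5):
    print(step, round(noise_std(step, eta=0.01, gamma=0.55), 5))
```

A larger gamma makes the noise die out faster; gamma = 0 keeps the variance fixed at eta for every step.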

Returns:

An optax.GradientTransformation object.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> key = jax.random.key(0)  # could also be key=0
>>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key)
>>> sgd = optax.scale_by_learning_rate(learning_rate=0.003)
>>> solver = optax.chain(noise, sgd)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
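Conceptually, the chained solver in the example above perturbs each gradient with annealed Gaussian noise and then takes a plain gradient step. The NumPy sketch below reproduces that loop under the same assumed schedule (variance = eta / (1 + t)**gamma); noisy_update and grad_f are hypothetical helpers, not optax functions, and because NumPy's random generator differs from JAX's PRNG the printed values will not match the doctest output exactly.

```python
import numpy as np


def noisy_update(grad, step, eta, gamma, rng):
    """Hypothetical helper: grad plus Gaussian noise with variance eta / (1 + step) ** gamma."""
    std = np.sqrt(eta / (1 + step) ** gamma)
    return grad + std * rng.standard_normal(grad.shape)


def grad_f(x):
    # Gradient of f(x) = sum(x ** 2).
    return 2.0 * x


rng = np.random.default_rng(0)
params = np.array([1.0, 2.0, 3.0])
learning_rate = 0.003
for step in range(5):
    g = noisy_update(grad_f(params), step, eta=0.01, gamma=0.55, rng=rng)
    # Descend along the noisy gradient, as scale_by_learning_rate does by default.
    params = params - learning_rate * g
    print(f"Objective function: {np.sum(params ** 2):.2E}")
```

In this sketch, setting eta = 0 removes the noise and recovers plain gradient descent.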

References

Neelakantan et al., Adding Gradient Noise Improves Learning for Very Deep Networks, 2015