optax.add_noise
- optax.add_noise(eta: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, key: Array | int | None = None, *, seed: int | None = None) → optax.GradientTransformation
Add gradient noise.
- Parameters:
eta – Base variance of the Gaussian noise added to the gradient.
gamma – Decay exponent for annealing the variance.
key – Random generator key (or integer seed) for noise generation.
seed – Deprecated; use key instead.
- Returns:
An optax.GradientTransformation object.
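To make the roles of eta and gamma concrete, the sketch below computes an annealed variance of the form eta / (1 + t) ** gamma, which is the schedule proposed in the referenced paper. This is an illustration only: the helper name noise_variance is hypothetical, and whether optax indexes the first step as t = 0 or t = 1 is an assumption not stated on this page.

```python
def noise_variance(step, eta=0.01, gamma=0.55):
    """Annealed noise variance eta / (1 + step) ** gamma.

    Hypothetical helper for illustration; not part of optax.
    Assumes the first update corresponds to step = 0.
    """
    return eta / (1.0 + step) ** gamma


# Variance for the first five steps: it starts at eta and decays monotonically.
variances = [noise_variance(t) for t in range(5)]
for t, v in enumerate(variances):
    print(f"step {t}: variance = {v:.6f}")
```

A larger gamma makes the noise vanish faster, while eta sets its initial scale.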
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> key = jax.random.key(0)  # could also be key=0
>>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key)
>>> sgd = optax.scale_by_learning_rate(learning_rate=0.003)
>>> solver = optax.chain(noise, sgd)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
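For intuition about what such a transformation does to a gradient, the following is a minimal hand-written sketch that adds zero-mean Gaussian noise to every leaf of a gradient pytree. It is not optax's implementation: the function add_gaussian_noise, its step argument, and the per-leaf key splitting are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def add_gaussian_noise(grads, key, step, eta=0.01, gamma=0.55):
    """Add N(0, eta / (1 + step) ** gamma) noise to each leaf of `grads`.

    Hypothetical helper for illustration; not the optax implementation.
    """
    std = jnp.sqrt(eta / (1.0 + step) ** gamma)
    leaves, treedef = jax.tree_util.tree_flatten(grads)
    # One independent subkey per leaf so the noise is uncorrelated across leaves.
    keys = jax.random.split(key, len(leaves))
    noisy_leaves = [
        g + std * jax.random.normal(k, g.shape, g.dtype)
        for g, k in zip(leaves, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, noisy_leaves)


grads = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
noisy = add_gaussian_noise(grads, jax.random.key(0), step=0)
```

The output has the same tree structure and shapes as the input, and reusing the same key reproduces the same noise.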
References
Neelakantan et al., Adding Gradient Noise Improves Learning for Very Deep Networks, 2015