optax.normalize_by_update_norm


optax.normalize_by_update_norm(scale_factor: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 1e-06) → optax.GradientTransformation

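Scale the updates by the inverse of their L2 norm, so that the returned update has a norm of approximately |scale_factor|, whatever the magnitude of the incoming update.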
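Written as a formula, the rule is sketched below. This assumes, consistent with the example output below, that the norm is the L2 norm taken over all leaves of the update pytree. Here $g$ is the incoming update, $s$ is scale_factor and $\varepsilon$ is eps:

```latex
u = \frac{s}{\lVert g \rVert_2 + \varepsilon}\, g,
\qquad
\lVert g \rVert_2 = \sqrt{\sum_{\ell} \sum_{k} g_{\ell,k}^{2}}
```

Here $\ell$ runs over the leaves of the pytree and $k$ over the entries of each leaf. Since $\lVert u \rVert_2 = |s|\,\lVert g \rVert_2 / (\lVert g \rVert_2 + \varepsilon)$, the step length stays close to $|s|$ whenever $\lVert g \rVert_2 \gg \varepsilon$.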

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.normalize_by_update_norm(scale_factor=-1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function:', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 7.52E+00
Objective function: 3.03E+00
Objective function: 5.50E-01
Objective function: 6.67E-02
Objective function: 5.50E-01
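To make the arithmetic of the first iteration concrete, here is a minimal pure-JAX sketch of the same rule. The helper `normalize_by_norm_sketch` is hypothetical and is not part of optax. It assumes the norm is computed once over the whole pytree, and optax's internal implementation may differ in detail.

```python
import jax
import jax.numpy as jnp


def normalize_by_norm_sketch(updates, scale_factor=1.0, eps=1e-6):
    """Hypothetical re-implementation (not the optax source): multiply every
    leaf of `updates` by scale_factor / (global L2 norm + eps)."""
    leaves = jax.tree_util.tree_leaves(updates)
    # One L2 norm shared by all leaves of the pytree.
    global_norm = jnp.sqrt(sum(jnp.sum(leaf ** 2) for leaf in leaves))
    factor = scale_factor / (global_norm + eps)
    return jax.tree_util.tree_map(lambda g: g * factor, updates)


def f(x):
    return jnp.sum(x ** 2)


# Reproduce the first iteration of the doctest by hand.
params = jnp.array([1.0, 2.0, 3.0])
grad = jax.grad(f)(params)  # [2., 4., 6.], L2 norm sqrt(56) ~ 7.483
update = normalize_by_norm_sketch(grad, scale_factor=-1.0)  # ~[-0.267, -0.535, -0.802]
params = params + update  # for a single array this matches optax.apply_updates
print("Objective function: {:.2E}".format(f(params)))

# On a pytree the norm is shared across leaves, not computed per leaf:
# the global norm of {w: [3, 0], b: [4]} is 5, giving w ~ [0.6, 0.0], b ~ [0.8].
tree = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
normalized = normalize_by_norm_sketch(tree)
```

Because the gradient of this quadratic always points along the parameters, each step shrinks the parameter norm by exactly |scale_factor| = 1 (from sqrt(14) ~ 3.74 down to ~ 2.74, ~ 1.74, ~ 0.74). The fourth step then overshoots the minimum, which is why the objective in the example rises again on the last iteration.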

Parameters:
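  • scale_factor – factor by which the normalized update is multiplied (defaults to 1.0). A negative value, as in the example above, flips the update's sign so the transform can be used directly for gradient descent.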

  • eps – small constant added to the norm before dividing, to avoid division by zero (defaults to 1e-6).

Returns:

An optax.GradientTransformation object.