optax.normalize_by_update_norm
- optax.normalize_by_update_norm(scale_factor: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 1e-06) → optax.GradientTransformation
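Scale updates by the inverse of their L2 norm, computed over the whole update (all leaves of the pytree together), so that every returned update has a norm of approximately |scale_factor|.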
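The rule can be sketched in plain JAX as u ← scale_factor · u / (‖u‖₂ + eps), applied to every leaf of the update pytree. This is an illustration, not optax's source: the function name `normalize_by_update_norm_sketch` is hypothetical, and treating ‖u‖₂ as a single norm over all leaves is an assumption.

```python
import jax
import jax.numpy as jnp


def normalize_by_update_norm_sketch(updates, scale_factor=1.0, eps=1e-6):
    """Hypothetical re-implementation of the rescaling rule (not optax's source).

    Assumption: the norm is one global L2 norm taken over every leaf of the pytree.
    """
    leaves = jax.tree_util.tree_leaves(updates)
    # Global L2 norm: square root of the sum of squares across all leaves.
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    # Every leaf is divided by the same scalar (norm + eps), then multiplied by scale_factor.
    return jax.tree_util.tree_map(lambda x: scale_factor * x / (norm + eps), updates)


grads = {'w': jnp.array([3.0, 0.0]), 'b': jnp.array([4.0])}  # global norm = 5
scaled = normalize_by_update_norm_sketch(grads, scale_factor=-1.0)
# scaled['w'] is approximately [-0.6, 0.0] and scaled['b'] approximately [-0.8]
```

Because one scalar divides every leaf, the direction of the update is preserved across the whole pytree; only its length changes.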
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.normalize_by_update_norm(scale_factor=-1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function:', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 7.52E+00
Objective function: 3.03E+00
Objective function: 5.50E-01
Objective function: 6.67E-02
Objective function: 5.50E-01
- Parameters:
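scale_factor – factor by which the normalized update is multiplied (defaults to 1.0). Its absolute value is approximately the norm of each returned update; use a negative value, such as -1.0 in the example, to step against the gradient.
eps – small constant added to the norm to avoid dividing by zero (defaults to 1e-6).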
- Returns:
An optax.GradientTransformation object.