optax.monitor

optax.monitor(measures: dict[str, base.GradientTransformationExtraArgs | Callable[[base.Updates], base.ArrayTree]])

Monitors stateful measurements of updates in a chain.

Extends optax.snapshot to support stateful measurements, such as an exponential moving average.

Parameters:

measures – A dictionary mapping measurement names to the gradient transformations, or plain functions of the updates, that compute them.

Returns:

A gradient transformation that captures measurements defined by the user.

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> def f(x): return jnp.sum(x ** 2)
>>> clip_thresh = 1.0
>>> solver = optax.chain(
...     optax.sgd(learning_rate=0.1, momentum=0.9),
...     optax.monitor({
...         'norm_before_clip': optax.tree.norm,
...         'is_clipped_ema': optax.measure_with_ema(
...             lambda x: optax.tree.norm(x) > clip_thresh,
...             decay=0.9,
...         )
...     }),
...     optax.clip_by_global_norm(clip_thresh),
... )
>>> params = jnp.array([1., 2., 3.])
>>> state = solver.init(params)
>>> for step in range(2):
...   grads = jax.grad(f)(params)
...   updates, state = solver.update(grads, state)
...   params = optax.apply_updates(params, updates)
...   norm_before_clip = optax.tree.get(state, 'norm_before_clip')
...   is_clipped_ema = optax.tree.get(state, 'is_clipped_ema')
...   print(f'{step=}, {norm_before_clip=:.2e}, {is_clipped_ema=:.2e}')
step=0, norm_before_clip=7.48e-01, is_clipped_ema=0.00e+00
step=1, norm_before_clip=1.27e+00, is_clipped_ema=5.26e-01
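
The is_clipped_ema values printed above can be checked by hand. The sketch below is plain Python and does not call optax. It assumes that optax.measure_with_ema applies a bias-corrected exponential moving average; this page does not state that, but the assumption is consistent with the printed output. The helper name debiased_ema is hypothetical and exists only for this illustration.

```python
def debiased_ema(values, decay):
    """Return the bias-corrected EMA after each element of `values`.

    Hypothetical helper for illustration only; not part of optax.
    """
    ema = 0.0
    results = []
    for count, value in enumerate(values, start=1):
        # Standard EMA update, starting from zero.
        ema = decay * ema + (1.0 - decay) * value
        # Bias correction compensates for the zero initialization.
        results.append(ema / (1.0 - decay ** count))
    return results


# Update norms as printed by the example above.
norms = [7.48e-01, 1.27e+00]
clip_thresh = 1.0

# The measured quantity is the indicator "norm exceeds the threshold".
is_clipped = [float(norm > clip_thresh) for norm in norms]

ema_values = debiased_ema(is_clipped, decay=0.9)
formatted = [f"{value:.2e}" for value in ema_values]
print(formatted)
```

At step 0 the norm is below the threshold, so the indicator and its average are both 0. At step 1 the indicator is 1, the raw average is 0.1, and dividing by 1 - 0.9 ** 2 = 0.19 gives roughly 0.526, which matches the second printed line.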