optax.monitor
- optax.monitor(measures: dict[str, base.GradientTransformationExtraArgs | Callable[[base.Updates], base.ArrayTree]])
Monitors stateful measurements of updates in a chain.
Extends `optax.snapshot` to support stateful measurements, such as an exponential moving average.
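As a worked equation, a stateful measurement such as an exponential moving average keeps a running value $m_t$ of a per-step quantity $x_t$. The sketch below assumes that `optax.measure_with_ema` also applies the usual bias correction; that assumption is inferred from the example's printed output, not from the library source. With $\beta$ = `decay` = 0.9, $x_1 = 0$ and $x_2 = 1$, it gives $\hat m_2 = 0.1 / 0.19 \approx 0.526$, matching `is_clipped_ema` at `step=1`.

```latex
m_0 = 0, \qquad
m_t = \beta\, m_{t-1} + (1 - \beta)\, x_t, \qquad
\hat{m}_t = \frac{m_t}{1 - \beta^{t}}
```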
- Parameters:
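measures – A dictionary mapping measurement names to either stateful gradient transformations (for example, one built with `optax.measure_with_ema`) or plain callables applied to the updates.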
- Returns:
A gradient transformation that captures measurements defined by the user.
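The contract described above can be illustrated with a toy, dependency-free sketch. It is not the optax implementation: `StatefulMeasure` and `toy_monitor` are hypothetical names, and the sketch assumes that a stateful measure is an `(init, update)` pair, that a plain callable simply records its value on the current updates, and that the monitor passes the updates through unchanged so later transformations in the chain still see them.

```python
from typing import Any, Callable, Dict, NamedTuple, Union


class StatefulMeasure(NamedTuple):
    """Hypothetical stand-in for a stateful measurement transformation."""
    init: Callable[[Any], Any]          # params -> initial measurement state
    update: Callable[[Any, Any], Any]   # (updates, state) -> new state


Measure = Union[StatefulMeasure, Callable[[Any], Any]]


def toy_monitor(measures: Dict[str, Measure]):
    """Returns an (init, update) pair that records measurements in its state."""

    def init(params):
        # Stateful measures get their own initial state; plain callables start empty.
        return {name: m.init(params) if isinstance(m, StatefulMeasure) else None
                for name, m in measures.items()}

    def update(updates, state):
        new_state = {}
        for name, m in measures.items():
            if isinstance(m, StatefulMeasure):
                # Stateful: combine the current updates with the previous state.
                new_state[name] = m.update(updates, state[name])
            else:
                # Stateless: snapshot of the measure on the current updates only.
                new_state[name] = m(updates)
        # Updates are returned untouched; only the state carries measurements.
        return updates, new_state

    return init, update
```

For instance, a running maximum of the largest update magnitude is stateful, while `sum` is a plain snapshot:

```python
running_max = StatefulMeasure(
    init=lambda params: 0.0,
    update=lambda u, s: max(s, max(abs(x) for x in u)),
)
init, update = toy_monitor({'max_abs': running_max, 'total': sum})
state = init([1.0, 2.0])
_, state = update([0.5, -0.5], state)
updates, state = update([1.0, 1.0], state)
```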
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> def f(x): return jnp.sum(x ** 2)
>>> clip_thresh = 1.0
>>> solver = optax.chain(
...     optax.sgd(learning_rate=0.1, momentum=0.9),
...     optax.monitor({
...         'norm_before_clip': optax.tree.norm,
...         'is_clipped_ema': optax.measure_with_ema(
...             lambda x: optax.tree.norm(x) > clip_thresh,
...             decay=0.9,
...         )
...     }),
...     optax.clip_by_global_norm(clip_thresh),
... )
>>> params = jnp.array([1., 2., 3.])
>>> state = solver.init(params)
>>> for step in range(2):
...     grads = jax.grad(f)(params)
...     updates, state = solver.update(grads, state)
...     params = optax.apply_updates(params, updates)
...     norm_before_clip = optax.tree.get(state, 'norm_before_clip')
...     is_clipped_ema = optax.tree.get(state, 'is_clipped_ema')
...     print(f'{step=}, {norm_before_clip=:.2e}, {is_clipped_ema=:.2e}')
step=0, norm_before_clip=7.48e-01, is_clipped_ema=0.00e+00
step=1, norm_before_clip=1.27e+00, is_clipped_ema=5.26e-01
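To see where the printed numbers come from, the example can be replayed by hand in plain Python. This is a sketch, not optax code, and it rests on three assumptions: `optax.sgd` with momentum accumulates `trace = grad + momentum * trace` and scales it by `-learning_rate`, the monitored updates are those produced by `sgd` before clipping, and `measure_with_ema` applies bias correction. The helper names `norm` and `simulate` are hypothetical.

```python
import math


def norm(v):
    # Global L2 norm of a flat list (stand-in for optax.tree.norm).
    return math.sqrt(sum(x * x for x in v))


def simulate(steps=2, lr=0.1, momentum=0.9, clip_thresh=1.0, decay=0.9):
    params = [1.0, 2.0, 3.0]
    trace = [0.0] * len(params)   # momentum buffer (assumed sgd behaviour)
    ema = 0.0                     # raw EMA of the "is clipped" indicator
    history = []
    for step in range(steps):
        grads = [2.0 * p for p in params]                 # gradient of sum(x ** 2)
        trace = [g + momentum * t for g, t in zip(grads, trace)]
        updates = [-lr * t for t in trace]                # sgd output, before clipping
        n = norm(updates)                                 # 'norm_before_clip'
        is_clipped = 1.0 if n > clip_thresh else 0.0
        ema = decay * ema + (1.0 - decay) * is_clipped
        ema_hat = ema / (1.0 - decay ** (step + 1))       # assumed bias correction
        history.append((step, n, ema_hat))
        # clip_by_global_norm: rescale so the global norm is at most clip_thresh.
        scale = clip_thresh / n if n > clip_thresh else 1.0
        params = [p + scale * u for p, u in zip(params, updates)]
    return history


for step, n, ema_hat in simulate():
    print(f'{step=}, norm_before_clip={n:.2e}, is_clipped_ema={ema_hat:.2e}')
```

Under these assumptions the replay prints the same two lines as the doctest: the first update has norm 0.748 and is not clipped, while the second, enlarged by momentum, has norm 1.27 and trips the clipping indicator, giving a bias-corrected EMA of 0.1 / 0.19 ≈ 0.526.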