optax.snapshot

optax.snapshot(measure_name: str, measure: Callable[[optax.Updates], optax.ArrayTree]) → optax.GradientTransformation

Takes a snapshot of updates and stores it in the state.

Useful for debugging intermediate update values in a chained transformation.

Parameters:
  • measure_name – Name of the measurement to store. It can then be used to retrieve the snapshot with optax.tree.get(state, measure_name).

  • measure – User-supplied callable that takes the updates as input and returns the desired measurement. When this transformation is part of a chain, the updates are the gradients as transformed by every transformation that precedes it in the chain.

Returns:

A gradient transformation that captures measurements defined by the user in the callable measure and stores them in the state with the name measure_name.

Examples

>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=0.1, momentum=0.9),
...     optax.snapshot('norm_before_clip', lambda x: optax.tree.norm(x)),
...     optax.clip_by_global_norm(0.05)
... )
>>> params = jnp.array([1., 2., 3.])
>>> state = solver.init(params)
>>> for step in range(2):
...   grads = jax.grad(f)(params)
...   updates, state = solver.update(grads, state)
...   params = optax.apply_updates(params, updates)
...   norm = optax.tree.get(state, 'norm_before_clip')
...   print(f'{step=}, {norm=:.2e}')
step=0, norm=7.48e-01
step=1, norm=1.41e+00