optax.named_chain

optax.named_chain(*args: tuple[str, base.GradientTransformation]) -> base.GradientTransformationExtraArgs

Applies a list of named chainable update transformations.

A variant of optax.chain() that lets you name each transformation.

Here the args are (name, transformation) pairs, each consisting of a string name and an associated gradient transformation. The transformation must be an instance of GradientTransformation or GradientTransformationExtraArgs.

Each name is used as the key for the state of the corresponding transformation within the named_chain state. Thus the state of the transformation with a given name can be retrieved as opt_state[name].

Parameters:

*args – an arbitrary number of (name, transform) pairs, each consisting of a string name and an associated transformation transform, which must be a GradientTransformation or GradientTransformationExtraArgs.

Returns:

A single (init_fn, update_fn) tuple.

Examples

>>> import optax
>>> opt1 = optax.scale(0.1)    # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2))
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # extra args for all transforms
... )
>>> tuple(new_state.keys()) == ("scale", "sgd")
True