optax.named_chain
optax.named_chain(*args: tuple[str, base.GradientTransformation]) -> base.GradientTransformationExtraArgs
Applies a list of named chainable update transformations.
A variant of optax.chain() that allows naming each transformation. Here the args are (name, transformation) pairs, each consisting of a string name and an associated gradient transformation, which must be an instance of GradientTransformation or GradientTransformationExtraArgs.
Each name is used as the key for the state of the corresponding transformation within the named_chain state, so the state of the transformation with a given name can be retrieved as opt_state[name].
Parameters:
*args - an arbitrary number of (name, transform) pairs, each consisting of a string name and an associated transformation transform, which is a GradientTransformation or GradientTransformationExtraArgs.
Returns:
A single GradientTransformationExtraArgs, i.e. an (init_fn, update_fn) tuple.
Examples
>>> import optax
>>> opt1 = optax.scale(0.1)  # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2))
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # extra args for all transforms
... )
>>> tuple(new_state.keys()) == ("scale", "sgd")
True