optax.chain

optax.chain(*args: base.GradientTransformation) -> base.GradientTransformationExtraArgs

Applies a list of chainable update transformations.

This function creates a new optax.GradientTransformation that applies a sequence of gradient transformations in order. The init function of the new transformation builds the optimizer state by concatenating the states of the individual transforms, while the update function applies the transforms in the given order, passing the output of each one as the input to the next.
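The behaviour described above can be illustrated with a minimal pure-Python sketch. It is not the optax implementation: `Transform`, `scale`, `count_steps` and `simple_chain` are hypothetical names introduced here, and plain floats stand in for the parameter and update pytrees that real transformations operate on.

```python
from typing import Any, Callable, NamedTuple


class Transform(NamedTuple):
    """Hypothetical stand-in for a gradient transformation (not the optax type)."""
    init: Callable[[Any], Any]
    update: Callable[..., tuple]


def scale(factor: float) -> Transform:
    # Stateless transform: multiplies the incoming update by a constant.
    def init(params):
        return ()

    def update(updates, state, params=None):
        return updates * factor, state

    return Transform(init, update)


def count_steps() -> Transform:
    # Stateful transform: passes the update through and counts calls.
    def init(params):
        return 0

    def update(updates, state, params=None):
        return updates, state + 1

    return Transform(init, update)


def simple_chain(*transforms: Transform) -> Transform:
    def init(params):
        # One state entry per transform, kept in chain order.
        return tuple(t.init(params) for t in transforms)

    def update(updates, state, params=None):
        new_state = []
        for t, s in zip(transforms, state):
            # Each transform receives the output of the previous one.
            updates, s = t.update(updates, s, params)
            new_state.append(s)
        return updates, tuple(new_state)

    return Transform(init, update)


chained = simple_chain(scale(2.0), count_steps(), scale(-0.5))
state = chained.init(1.0)
updates, state = chained.update(3.0, state, 1.0)
```

In this sketch the chained state is a tuple with one entry per transform, so stateless transforms contribute an empty entry while stateful ones carry their own state forward between calls.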

Parameters:

*args – an arbitrary number of transforms of type GradientTransformation or GradientTransformationExtraArgs.

Returns:

A GradientTransformationExtraArgs, created by chaining the input transformations. Note that independent of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.
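The forwarding rule for extra arguments can be sketched in plain Python as follows. This is only an illustration under stated assumptions, not the library's code: `Plain`, `WithExtraArgs` and `forward_update` are hypothetical names, and the type check is an assumed way of telling the two kinds of transform apart.

```python
from typing import Callable, NamedTuple


class Plain(NamedTuple):
    """Hypothetical transform whose update takes no extra args."""
    update: Callable[..., tuple]


class WithExtraArgs(NamedTuple):
    """Hypothetical transform whose update accepts keyword extra args."""
    update: Callable[..., tuple]


def forward_update(transforms, updates, states, params=None, **extra_args):
    new_states = []
    for t, s in zip(transforms, states):
        if isinstance(t, WithExtraArgs):
            # Only transforms that support extra args receive them.
            updates, s = t.update(updates, s, params, **extra_args)
        else:
            updates, s = t.update(updates, s, params)
        new_states.append(s)
    return updates, tuple(new_states)


def halve_update(updates, state, params=None):
    # Would raise a TypeError if it were handed an unexpected keyword.
    return updates * 0.5, state


def scale_by_value_update(updates, state, params=None, *, value, **unused):
    # Uses the `value` extra arg and ignores any other extra args.
    return updates * value, state


transforms = [Plain(halve_update), WithExtraArgs(scale_by_value_update)]
result, new_states = forward_update(transforms, 4.0, [(), ()], value=3.0)
```

Here `halve_update` never sees `value`, which is why a chain can mix transforms with and without extra-arg support.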

Examples

A transform that scales the Adam update by -0.1:

>>> import optax
>>> transform1 = optax.scale_by_adam()
>>> transform2 = optax.scale(-0.1)
>>> chained_transform = optax.chain(transform1, transform2)
>>> params = {'a': 1.0}
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)

An optimizer in the chain might require extra args:

>>> import optax
>>> opt1 = optax.scale(0.1)    # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.chain(opt1, opt2)
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # forwarded to transforms that accept extra args
... )