optax.chain
- optax.chain(*args: base.GradientTransformation) -> base.GradientTransformationExtraArgs
Applies a list of chainable update transformations.
This function creates a new optax.GradientTransformation() that applies a sequence of gradient transformations in order. The init function of the new transformation constructs the optimizer state by concatenating the states of the individual transforms, while the update function applies the updates in the given order.
- Parameters:
*args – an arbitrary number of transforms, each a GradientTransformation or a GradientTransformationExtraArgs.
- Returns:
A GradientTransformationExtraArgs, created by chaining the input transformations. Regardless of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation are forwarded only to those transformations in the chain that support extra args.
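The behaviour described above (one state entry per transform, updates applied in order, extra args forwarded only where supported) can be illustrated with a small pure-Python sketch. This is not optax's implementation: `Transform`, `scale`, `add_value`, and `accepts_extra_args` are hypothetical stand-ins, and detecting extra-arg support by inspecting the function signature is an assumption made only for this illustration.

```python
import inspect
from typing import Any, Callable, NamedTuple


class Transform(NamedTuple):
    """Hypothetical stand-in for a gradient transformation (not an optax type)."""
    init: Callable[..., Any]
    update: Callable[..., Any]


def scale(factor: float) -> Transform:
    """Stateless transform that multiplies the update by `factor`."""
    def init(params):
        return ()  # nothing to track

    def update(updates, state, params=None):
        return updates * factor, state

    return Transform(init, update)


def add_value(weight: float) -> Transform:
    """Transform whose update takes a `value` extra arg and counts its calls."""
    def init(params):
        return 0  # step counter

    def update(updates, state, params=None, *, value=0.0):
        return updates + weight * value, state + 1

    return Transform(init, update)


def accepts_extra_args(fn) -> bool:
    """Assumption for this sketch: keyword-only params or **kwargs mean support."""
    kinds = {p.kind for p in inspect.signature(fn).parameters.values()}
    return bool(kinds & {inspect.Parameter.KEYWORD_ONLY,
                         inspect.Parameter.VAR_KEYWORD})


def chain(*transforms: Transform) -> Transform:
    def init(params):
        # One state entry per transform, kept in chain order.
        return tuple(t.init(params) for t in transforms)

    def update(updates, state, params=None, **extra_args):
        new_state = []
        for t, s in zip(transforms, state):
            # Forward extra args only to transforms that can accept them.
            kwargs = extra_args if accepts_extra_args(t.update) else {}
            updates, s = t.update(updates, s, params, **kwargs)
            new_state.append(s)
        return updates, tuple(new_state)

    return Transform(init, update)


chained = chain(scale(0.5), add_value(2.0))
state = chained.init(1.0)                                  # ((), 0)
updates, new_state = chained.update(4.0, state, 1.0, value=3.0)
print(updates, new_state)                                  # 8.0 ((), 1)
```

In the sketch, `scale` never sees `value`, while `add_value` receives it; the chained update is `4.0 * 0.5 + 2.0 * 3.0`.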
Examples
A transform that scales the Adam update by -0.1:
>>> import optax
>>> transform1 = optax.scale_by_adam()
>>> transform2 = optax.scale(-0.1)
>>> chained_transform = optax.chain(transform1, transform2)
>>> params = {'a': 1.0}
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)
An optimizer in the chain might require extra args:
>>> import optax
>>> opt1 = optax.scale(0.1)  # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.chain(opt1, opt2)
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...   0.7, state, 0.7, **extra_args  # extra args for all transforms
... )