Combining Optimizers#

chain(*args)

Applies a list of chainable update transformations.

named_chain(*args)

Applies a list of named chainable update transformations.

partition(transforms, param_labels, *[, ...])

Partitions params and applies a different transformation to each subset.

Chain#

optax.chain(*args: base.GradientTransformation) → base.GradientTransformationExtraArgs[source]#

Applies a list of chainable update transformations.

This function creates a new optax.GradientTransformation() that applies a sequence of gradient transformations in order. The init function of the new transformation constructs the optimizer state by concatenating the states of the individual transforms, while the update function applies the updates in the given order.

Parameters:

*args – an arbitrary number of transforms, each a GradientTransformation or GradientTransformationExtraArgs.

Returns:

A GradientTransformationExtraArgs, created by chaining the input transformations. Note that independent of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.

Examples

A transform that scales the Adam update by -0.1:

>>> import optax
>>> transform1 = optax.scale_by_adam()
>>> transform2 = optax.scale(-0.1)
>>> chained_transform = optax.chain(transform1, transform2)
>>> params = {'a': 1.0}
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)
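
As a minimal sketch of what the init function produces (assuming the combined state is a flat tuple holding one inner state per transform, which is how current optax versions represent it), the chained state can be inspected directly:

>>> len(state) == 2  # one inner state for scale_by_adam, one for scale
True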

An optimizer in the chain might require extra args:

>>> import optax
>>> opt1 = optax.scale(0.1)    # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.chain(opt1, opt2)
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # extra args for all transforms
... )
optax.named_chain(*args: tuple[str, base.GradientTransformation]) → base.GradientTransformationExtraArgs[source]#

Applies a list of named chainable update transformations.

A variant of optax.chain() that allows naming each transformation.

Here the args are (name, transformation) pairs, each consisting of a string name and an associated gradient transformation. The gradient transformation must be an instance of GradientTransformation or GradientTransformationExtraArgs.

Each name is used as the key for the state of the corresponding transformation within the named_chain state. Thus the state of a transformation with a given name can easily be retrieved as opt_state[name].

Parameters:

*args – an arbitrary number of (name, transform) pairs, each consisting of a string name and an associated transform. The latter must be a GradientTransformation or GradientTransformationExtraArgs.

Returns:

A single (init_fn, update_fn) tuple.

Examples

>>> import optax
>>> opt1 = optax.scale(0.1)    # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2))
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # extra args for all transforms
... )
>>> tuple(new_state.keys()) == ("scale", "sgd")
True
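
Building on the doctest above, the inner state of a particular transform can then be retrieved by the name it was given (here "scale" and "sgd" are simply the labels chosen in this example):

>>> scale_state = new_state["scale"]  # state of the scale transform
>>> sgd_state = new_state["sgd"]      # state of the polyak_sgd transform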

Partition#

optax.partition(transforms: Mapping[Hashable, base.GradientTransformation], param_labels: base.PyTree | Callable[[base.PyTree], base.PyTree], *, mask_compatible_extra_args: bool = False) → base.GradientTransformationExtraArgs[source]#

Partitions params and applies a different transformation to each subset.

Sometimes you may want to apply different transformations to different parameters. For example, you may want to apply Adam to the weights of a neural network, but SGD to the biases. This function allows you to do that.

Parameters:
  • transforms – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.

  • param_labels – A PyTree that is the same shape as (or a prefix of) the parameters/updates, or a function that returns such a PyTree given the parameters as input. The leaves of this PyTree correspond to the keys of transforms (so the values at the leaves must be a subset of those keys).

  • mask_compatible_extra_args – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Returns:

An optax.GradientTransformationExtraArgs() that implements an init and an update function.

Examples

Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp

>>> def map_nested_fn(fn):
...   '''Recursively apply `fn` to key-value pairs of a nested dict.'''
...   def map_fn(nested_dict):
...     return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
...             for k, v in nested_dict.items()}
...   return map_fn

>>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
...           'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
>>> gradients = jax.tree.map(jnp.ones_like, params)  # dummy gradients

>>> label_fn = map_nested_fn(lambda k, _: k)
>>> tx = optax.partition(
...     {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
>>> state = tx.init(params)
>>> updates, new_state = tx.update(gradients, state, params)
>>> new_params = optax.apply_updates(params, updates)
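
The resulting state is an optax.PartitionState (documented below). As a rough sketch, assuming its inner_states field is a mapping keyed by the labels (with each value possibly wrapped by an internal masking transform), the per-label states can be accessed as follows:

>>> adam_state = new_state.inner_states['w']  # state for params labeled 'w'
>>> sgd_state = new_state.inner_states['b']   # state for params labeled 'b'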

Instead of providing a label_fn, you may provide a PyTree of labels directly. This PyTree may also be a prefix of the parameters PyTree, as demonstrated in the GAN pseudocode below:

>>> generator_params = ...
>>> discriminator_params = ...
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')

>>> tx = optax.partition(
...     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
...     param_labels)

If you would like to leave some parameters unoptimized, you may wrap optax.partition() with optax.masked().
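
As an alternative sketch (reusing params, gradients and label_fn from the two-layer example above, and not part of the docstring itself), mapping a label to optax.set_to_zero() inside the partition yields zero updates for those parameters, which also effectively freezes them:

>>> frozen_tx = optax.partition(
...     {'w': optax.adam(1.0), 'b': optax.set_to_zero()}, label_fn)
>>> state = frozen_tx.init(params)
>>> updates, _ = frozen_tx.update(gradients, state, params)
>>> bool(jnp.all(updates['linear_1']['b'] == 0.0))  # 'b' params get zero updates
True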

class optax.PartitionState(inner_states)[source]#