Combining Optimizers#

chain(*args)

Applies a list of chainable update transformations.

multi_transform(transforms, param_labels, *)

Partitions params and applies a different transformation to each subset.

Chain#

optax.chain(*args: base.GradientTransformation) → base.GradientTransformationExtraArgs#

Applies a list of chainable update transformations.

This function creates a new optax.GradientTransformation() that applies a sequence of gradient transformations in order. The init function of the new transformation constructs the optimizer state by concatenating the states of the individual transforms, while the update function applies the updates in the given order.
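
To make this composition explicit, here is a hedged sketch of how two transformations could be chained by hand. It is not the actual optax implementation, and the helper name manual_chain is made up for illustration:

>>> import optax
>>> def manual_chain(t1, t2):
...   '''Illustrative only: chain two gradient transformations by hand.'''
...   def init_fn(params):
...     # The combined state is simply a tuple of the individual states.
...     return (t1.init(params), t2.init(params))
...   def update_fn(updates, state, params=None):
...     # Apply the transformations in order, threading the updates through.
...     updates, new_state_1 = t1.update(updates, state[0], params)
...     updates, new_state_2 = t2.update(updates, state[1], params)
...     return updates, (new_state_1, new_state_2)
...   return optax.GradientTransformation(init_fn, update_fn)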

Examples

A transform that scales the Adam update by -0.1:

>>> import optax
>>> transform1 = optax.scale_by_adam()
>>> transform2 = optax.scale(-0.1)
>>> chained_transform = optax.chain(transform1, transform2)
>>> params = {'a': 1.0}
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)
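
This is the same pattern used to assemble the built-in optimizers; as a hedged aside, with default hyperparameters the chain above should behave like optax.adam(learning_rate=0.1):

>>> adam = optax.adam(learning_rate=0.1)  # roughly chain(scale_by_adam(), scale(-0.1))
>>> adam_state = adam.init(params)
>>> adam_updates, _ = adam.update({'a': -0.5}, adam_state, params)
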
Parameters:

*args – a sequence of chainable (init_fn, update_fn) tuples.

Returns:

A GradientTransformationExtraArgs(), created by chaining the input transformations. Note that, regardless of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.
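
As a hedged illustration of this pass-through behavior, reusing the chained transform from the example above: the keyword argument value below is arbitrary and is simply ignored, because neither optax.scale_by_adam() nor optax.scale() consumes extra args.

>>> grads = {'a': -0.5}
>>> state = chained_transform.init(params)
>>> # The extra keyword is accepted but unused by these two transforms.
>>> updates, state = chained_transform.update(grads, state, params, value=0.0)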

Multi-transform#

optax.multi_transform(transforms: Mapping[Hashable, base.GradientTransformation], param_labels: base.PyTree | Callable[[base.PyTree], base.PyTree], *, mask_compatible_extra_args: bool = False) → base.GradientTransformationExtraArgs#

Partitions params and applies a different transformation to each subset.

Sometimes you may want to apply different transformations to different parameters. For example, you may want to apply Adam to the weights of a neural network, but SGD to the biases. This function allows you to do that.

Examples

Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp

>>> def map_nested_fn(fn):
...   '''Recursively apply `fn` to key-value pairs of a nested dict.'''
...   def map_fn(nested_dict):
...     return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
...             for k, v in nested_dict.items()}
...   return map_fn

>>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
...           'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
>>> gradients = jax.tree.map(jnp.ones_like, params)  # dummy gradients

>>> label_fn = map_nested_fn(lambda k, _: k)
>>> tx = optax.multi_transform(
...     {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
>>> state = tx.init(params)
>>> updates, new_state = tx.update(gradients, state, params)
>>> new_params = optax.apply_updates(params, updates)

Instead of providing a label_fn, you may provide a PyTree of labels directly. Also, this PyTree may be a prefix of the parameters PyTree. This is demonstrated in the GAN pseudocode below:

>>> generator_params = ...
>>> discriminator_params = ...
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')

>>> tx = optax.multi_transform(
...     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
...     param_labels)
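
To make the prefix labeling concrete, below is a runnable variant of the pseudocode; the parameter shapes are placeholders invented for illustration:

>>> generator_params = {'w': jnp.zeros((3, 2))}      # placeholder shapes
>>> discriminator_params = {'w': jnp.zeros((2, 1))}  # placeholder shapes
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')    # a prefix of all_params
>>> tx = optax.multi_transform(
...     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
...     param_labels)
>>> state = tx.init(all_params)
>>> grads = jax.tree.map(jnp.ones_like, all_params)
>>> updates, state = tx.update(grads, state, all_params)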

If you would like to not optimize some parameters, you can wrap optax.multi_transform() with optax.masked(), as in the sketch below.
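
A minimal sketch of that pattern, reusing params, gradients, map_nested_fn, and tx from the first example. The mask below is True only for the 'w' leaves, so only those are routed through the inner transformation; with optax.masked(), updates for the masked-out leaves are passed through unchanged:

>>> mask = map_nested_fn(lambda k, _: k == 'w')(params)  # True for weights only
>>> masked_tx = optax.masked(tx, mask)
>>> masked_state = masked_tx.init(params)
>>> masked_updates, masked_state = masked_tx.update(gradients, masked_state, params)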

Parameters:
  • transforms – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.

  • param_labels – A PyTree that is the same shape as, or a prefix of, the parameters/updates PyTree (or a function that returns such a PyTree given the parameters as input). The leaves of this PyTree correspond to the keys of transforms (therefore the values at the leaves must be a subset of the keys).

  • mask_compatible_extra_args – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Returns:

An optax.GradientTransformationExtraArgs() that implements init and update functions.

optax.MultiTransformState#

alias of PartitionState