Combining Optimizers#


chain(*args)

Applies a list of chainable update transformations.

multi_transform(transforms, param_labels, *)

Partitions params and applies a different transformation to each subset.



optax.chain(*args)[source]#

Applies a list of chainable update transformations.

This function creates a new optax.GradientTransformation() that applies a sequence of gradient transformations in order. The init function of the new transformation constructs the optimizer state by concatenating the states of the individual transforms, while the update function applies the updates in the given order.


For example, a transform that scales the Adam update by -0.1:

>>> import optax
>>> transform1 = optax.scale_by_adam()
>>> transform2 = optax.scale(-0.1)
>>> chained_transform = optax.chain(transform1, transform2)
>>> params = {'a': 1.0}
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)

Parameters:

  • *args (GradientTransformation) – a sequence of chainable (init_fn, update_fn) tuples.

Return type:

A GradientTransformationExtraArgs(), created by chaining the input transformations. Note that, regardless of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.


optax.multi_transform(transforms, param_labels, *, mask_compatible_extra_args=False)[source]#

Partitions params and applies a different transformation to each subset.

Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:

import optax
import jax
import jax.numpy as jnp

def map_nested_fn(fn):
  '''Recursively apply `fn` to the key-value pairs of a nested dict.'''
  def map_fn(nested_dict):
    return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()}
  return map_fn

params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
          'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

label_fn = map_nested_fn(lambda k, _: k)
tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)},
                           label_fn)
state = tx.init(params)
updates, new_state = tx.update(gradients, state, params)
new_params = optax.apply_updates(params, updates)

Instead of providing a label_fn, you may provide a PyTree of labels directly. Also, this PyTree may be a prefix of the parameters PyTree. This is demonstrated in the GAN pseudocode below:

generator_params = ...
discriminator_params = ...
all_params = (generator_params, discriminator_params)
param_labels = ('generator', 'discriminator')

tx = optax.multi_transform(
    {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
    param_labels)

If you would like to not optimize some parameters, you may wrap optax.multi_transform() with optax.masked().

Parameters:

  • transforms (Mapping[Hashable, GradientTransformation]) – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.

  • param_labels (Union[Any, Callable[[Any], Any]]) – A PyTree that is the same shape or a prefix of the parameters/updates (or a function that returns one given the parameters as input). The leaves of this PyTree correspond to the keys of the transforms (therefore the values at the leaves must be a subset of the keys).

  • mask_compatible_extra_args (bool) – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Return type:

An optax.GradientTransformationExtraArgs() that implements an init and update function.

class optax.MultiTransformState(inner_states)[source]#
inner_states: Mapping[Hashable, base.OptState]#

A mapping from labels to the states of the corresponding inner transformations.

