optax.partition

optax.partition(transforms: Mapping[Hashable, base.GradientTransformation], param_labels: base.PyTree | Callable[[base.PyTree], base.PyTree], *, mask_compatible_extra_args: bool = False) → base.GradientTransformationExtraArgs

Partitions params and applies a different transformation to each subset.

Sometimes you may want to apply different transformations to different parameters. For example, you may want to apply Adam to the weights of a neural network, but SGD to the biases. This function allows you to do that.

Parameters:
  • transforms – A mapping from labels to transformations. Each transformation is applied only to the parameters with the same label.

  • param_labels – A PyTree with the same structure as the parameters/updates, or a prefix of it (or a function that returns such a PyTree given the parameters as input). The leaves of this PyTree are labels, and every label must be one of the keys of transforms.

  • mask_compatible_extra_args – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Returns:

An optax.GradientTransformationExtraArgs that implements an init function and an update function.

Examples

Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp

>>> def map_nested_fn(fn):
...   '''Recursively apply `fn` to key-value pairs of a nested dict.'''
...   def map_fn(nested_dict):
...     return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
...             for k, v in nested_dict.items()}
...   return map_fn

>>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
...           'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
>>> gradients = jax.tree.map(jnp.ones_like, params)  # dummy gradients

>>> label_fn = map_nested_fn(lambda k, _: k)
>>> tx = optax.partition(
...     {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
>>> state = tx.init(params)
>>> updates, new_state = tx.update(gradients, state, params)
>>> new_params = optax.apply_updates(params, updates)

Instead of providing a label_fn, you may provide a PyTree of labels directly. Also, this PyTree may be a prefix of the parameters PyTree. This is demonstrated in the GAN pseudocode below:

>>> generator_params = ...
>>> discriminator_params = ...
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')

>>> tx = optax.partition(
...     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
...     param_labels)

If you do not want to optimize some parameters, you may wrap optax.partition() with optax.masked().