optax.partition
- optax.partition(transforms: Mapping[Hashable, base.GradientTransformation], param_labels: base.PyTree | Callable[[base.PyTree], base.PyTree], *, mask_compatible_extra_args: bool = False) → base.GradientTransformationExtraArgs
Partitions params and applies a different transformation to each subset.
Sometimes you may want to apply different transformations to different parameters. For example, you may want to apply Adam to the weights of a neural network, but SGD to the biases. This function allows you to do that.
- Parameters:
transforms – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.
param_labels – A PyTree that has the same structure as, or is a prefix of, the parameters/updates (or a function that returns such a PyTree given the parameters as input). The leaves of this PyTree are labels that select entries of transforms, so every leaf value must be one of the keys of transforms.
mask_compatible_extra_args – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.
- Returns:
An optax.GradientTransformationExtraArgs() that implements an init and an update function.
Examples
Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def map_nested_fn(fn):
...   '''Recursively apply `fn` to key-value pairs of a nested dict.'''
...   def map_fn(nested_dict):
...     return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
...             for k, v in nested_dict.items()}
...   return map_fn
>>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
...           'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
>>> gradients = jax.tree.map(jnp.ones_like, params)  # dummy gradients
>>> label_fn = map_nested_fn(lambda k, _: k)
>>> tx = optax.partition(
...     {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
>>> state = tx.init(params)
>>> updates, new_state = tx.update(gradients, state, params)
>>> new_params = optax.apply_updates(params, updates)
Instead of providing a label_fn, you may provide a PyTree of labels directly. This PyTree may also be a prefix of the parameters PyTree, as demonstrated in the GAN pseudocode below:
>>> generator_params = ...
>>> discriminator_params = ...
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')
>>> tx = optax.partition(
...     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
...     param_labels)
If you do not want to optimize some parameters, you may wrap optax.partition() with optax.masked().