Combining Optimizers#
| optax.chain | Applies a list of chainable update transformations. |
| optax.named_chain | Applies a list of named chainable update transformations. |
| optax.partition | Partitions params and applies a different transformation to each subset. |
Chain#
- optax.chain(*args: base.GradientTransformation) → base.GradientTransformationExtraArgs
Applies a list of chainable update transformations.
This function creates a new optax.GradientTransformation() that applies a sequence of gradient transformations in order. The init function of the new transformation constructs the optimizer state by concatenating the states of the individual transforms, while the update function applies the updates in the given order.
- Parameters:
*args – an arbitrary number of transforms of GradientTransformation or GradientTransformationExtraArgs.
- Returns:
A GradientTransformationExtraArgs, created by chaining the input transformations. Note that, independent of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.
Examples
A transform that scales the Adam update by -0.1:
>>> import optax
>>> transform1 = optax.scale_by_adam()
>>> transform2 = optax.scale(-0.1)
>>> chained_transform = optax.chain(transform1, transform2)
>>> params = {'a': 1.0}
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)
An optimizer in the chain might require extra args:
>>> import optax
>>> opt1 = optax.scale(0.1)  # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.chain(opt1, opt2)
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # extra args for all transforms
... )
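Since the chained init concatenates the states of the individual transforms, the resulting state can be inspected per component. A minimal sketch of this structure (the contents of each sub-state are omitted; only the tuple layout is assumed here):
>>> import optax
>>> tx = optax.chain(optax.scale_by_adam(), optax.scale(-0.1))
>>> state = tx.init({'a': 1.0})
>>> len(state)  # one sub-state per chained transformation, in chain order
2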
- optax.named_chain(*args: tuple[str, base.GradientTransformation]) → base.GradientTransformationExtraArgs
Applies a list of named chainable update transformations.
A variant of optax.chain() that allows naming each transformation.
Here the args are (name, transformation) pairs: each consists of a string name and an associated transformation, which must be an instance of GradientTransformation or GradientTransformationExtraArgs.
Each name is used as the key for the state of the corresponding transformation within the named_chain state, so the state of the transformation with a given name can be retrieved as opt_state[name].
- Parameters:
*args – an arbitrary number of (name, transform) pairs, each consisting of a string name and an associated transformation transform. The latter is a GradientTransformation or GradientTransformationExtraArgs.
- Returns:
A single (init_fn, update_fn) tuple.
Examples
>>> import optax
>>> opt1 = optax.scale(0.1)  # scale incoming gradients
>>> opt2 = optax.polyak_sgd()  # requires a `value` extra arg for `update`
>>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2))
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
...     0.7, state, 0.7, **extra_args  # extra args for all transforms
... )
>>> tuple(new_state.keys()) == ("scale", "sgd")
True
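Continuing the example above, each sub-state can be read directly from the chained state by name, which is convenient for logging or checkpointing individual optimizers (the variable names below simply reuse those from the example):
>>> scale_state = new_state["scale"]  # state of the `optax.scale` transform
>>> sgd_state = new_state["sgd"]      # state of the `optax.polyak_sgd` transform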
Partition#
- optax.partition(transforms: Mapping[Hashable, base.GradientTransformation], param_labels: base.PyTree | Callable[[base.PyTree], base.PyTree], *, mask_compatible_extra_args: bool = False) → base.GradientTransformationExtraArgs
Partitions params and applies a different transformation to each subset.
Sometimes you may want to apply different transformations to different parameters. For example, you may want to apply Adam to the weights of a neural network, but SGD to the biases. This function allows you to do that.
- Parameters:
transforms – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.
param_labels – A PyTree that is the same shape as, or a prefix of, the parameters/updates (or a function that returns one given the parameters as input). The leaves of this PyTree correspond to the keys of the transforms (therefore the values at the leaves must be a subset of the keys).
mask_compatible_extra_args – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.
- Returns:
An optax.GradientTransformationExtraArgs() that implements init and update functions.
Examples
Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def map_nested_fn(fn):
...     '''Recursively apply `fn` to key-value pairs of a nested dict.'''
...     def map_fn(nested_dict):
...         return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
...                 for k, v in nested_dict.items()}
...     return map_fn
>>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
...           'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
>>> gradients = jax.tree.map(jnp.ones_like, params)  # dummy gradients
>>> label_fn = map_nested_fn(lambda k, _: k)
>>> tx = optax.partition(
...     {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
>>> state = tx.init(params)
>>> updates, new_state = tx.update(gradients, state, params)
>>> new_params = optax.apply_updates(params, updates)
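To check how the labels line up with the keys of the transforms mapping, it can help to inspect the label PyTree that label_fn produces for these params. Continuing the example above (the printed dict is illustrative and simply reflects standard insertion-ordered dict repr):
>>> label_fn(params)
{'linear_1': {'w': 'w', 'b': 'b'}, 'linear_2': {'w': 'w', 'b': 'b'}}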
Instead of providing a label_fn, you may provide a PyTree of labels directly. Also, this PyTree may be a prefix of the parameters PyTree. This is demonstrated in the GAN pseudocode below:
>>> generator_params = ...
>>> discriminator_params = ...
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')
>>> tx = optax.partition(
>>>     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
>>>     param_labels)
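For reference, here is a runnable sketch of the same prefix-labeling idea with small placeholder parameters; the shapes and learning rates are arbitrary and chosen only for illustration:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> generator_params = {'w': jnp.zeros((3, 3))}
>>> discriminator_params = {'w': jnp.zeros((3, 1))}
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')  # a prefix of `all_params`
>>> tx = optax.partition(
...     {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
...     param_labels)
>>> state = tx.init(all_params)
>>> grads = jax.tree.map(jnp.ones_like, all_params)  # dummy gradients
>>> updates, state = tx.update(grads, state, all_params)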
If you would like to not optimize some parameters, you may wrap optax.partition() with optax.masked().
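A related pattern, shown here as an illustrative sketch rather than the masked wrapping described above, is to map a dedicated label to optax.set_to_zero(), so that the corresponding parameters receive zero updates and stay fixed; the label names and shapes below are arbitrary:
>>> import optax
>>> import jax.numpy as jnp
>>> params = {'backbone': jnp.zeros(3), 'head': jnp.zeros(2)}
>>> tx = optax.partition(
...     {'freeze': optax.set_to_zero(), 'train': optax.adam(1e-3)},
...     {'backbone': 'freeze', 'head': 'train'})
>>> state = tx.init(params)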