optax.selective_transform
- optax.selective_transform(optimizer: optax.GradientTransformation, *, freeze_mask: bool | optax.ArrayTree) → optax.GradientTransformation
Partition updates so that only unfrozen parameters are optimized.
Example
>>> import jax.numpy as jnp
>>> import optax
>>> from optax import selective_transform
>>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
>>> mask = {'a': True, 'b': False}  # Freeze 'a', train 'b'
>>> selective_opt = selective_transform(optax.adam(1e-3), freeze_mask=mask)
- Parameters:
optimizer – The inner Optax optimizer to apply to unfrozen leaves.
freeze_mask – A static mask (i.e., not dependent on runtime values or updated during training). It can be either:
a scalar bool (or 0-d JAX bool array) to freeze everything (True) or nothing (False), or
a PyTree of booleans mirroring the parameter tree, marking each leaf to freeze (True) or train (False).
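A scalar `freeze_mask` applies the same flag to every leaf, while a PyTree mask is read leaf by leaf. The helper below is hypothetical (not part of Optax) and only sketches that broadcasting, assuming a scalar means "same flag for every leaf". Calling `bool()` on the mask works only for concrete values, which is one way to see why the mask has to be static.

```python
import jax
import jax.numpy as jnp


def _is_scalar_mask(mask):
    # A Python bool or a 0-d array (e.g. jnp.array(True)) counts as a scalar.
    return isinstance(mask, bool) or getattr(mask, "ndim", None) == 0


def broadcast_freeze_mask(params, freeze_mask):
    """Hypothetical helper: return one Python bool per leaf of `params`.

    True means "freeze this leaf", False means "train it".
    """
    if _is_scalar_mask(freeze_mask):
        # bool() needs a concrete (non-traced) value, so the mask must be static.
        flag = bool(freeze_mask)
        return jax.tree_util.tree_map(lambda _: flag, params)
    # PyTree mask: expected to mirror `params` leaf for leaf.
    return jax.tree_util.tree_map(lambda _, m: bool(m), params, freeze_mask)


params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
print(broadcast_freeze_mask(params, True))                     # {'a': True, 'b': True}
print(broadcast_freeze_mask(params, jnp.array(False)))         # {'a': False, 'b': False}
print(broadcast_freeze_mask(params, {'a': True, 'b': False}))  # {'a': True, 'b': False}
```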
- Returns:
A GradientTransformation that routes each parameter leaf through:
the given optimizer if its mask is False (“train”), or
set_to_zero() if its mask is True (“freeze”).
- Return type:
optax.GradientTransformation
See also
optax.freeze(): For simply zeroing out gradients according to a mask.