optax.selective_transform


optax.selective_transform(optimizer: optax.GradientTransformation, *, freeze_mask: bool | optax.ArrayTree) -> optax.GradientTransformation

Partition updates so that only unfrozen parameters are optimized.

Example

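>>> import jax.numpy as jnp
>>> import optax
>>> from optax import selective_transform
>>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
>>> mask = {'a': True, 'b': False}  # Freeze 'a', train 'b'
>>> selective_opt = selective_transform(optax.adam(1e-3), freeze_mask=mask)
>>> state = selective_opt.init(params)
>>> grads = {'a': jnp.ones(1), 'b': jnp.ones(2)}
>>> updates, state = selective_opt.update(grads, state, params)
>>> params = optax.apply_updates(params, updates)  # 'a' stays at zero
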
Parameters:
  • optimizer – The inner Optax optimizer to apply to unfrozen leaves.

  • freeze_mask – A static mask (i.e., not dependent on runtime values or updated during training). It can be either:

    • a scalar bool (or 0-d JAX bool array) to freeze everything (True) or nothing (False), or

    • a PyTree of booleans mirroring the parameter tree, marking each leaf to freeze (True) or train (False).
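To illustrate how the two accepted mask forms relate, here is a minimal sketch that normalizes either form into one label per parameter leaf. The helper name `freeze_labels` and the `'freeze'`/`'train'` label strings are hypothetical (the labels mirror the wording of the Returns section); this is not optax's implementation. Because the sketch calls `bool()` on each mask value, the mask must be concrete, which is one way to see why it has to be static.

```python
import jax
import jax.numpy as jnp
import numpy as np


def freeze_labels(freeze_mask, params):
    """Hypothetical helper: map a freeze mask to per-leaf 'freeze'/'train' labels.

    A scalar bool (or 0-d bool array) applies to every leaf of `params`.
    Otherwise `freeze_mask` is assumed to be a PyTree of booleans with the
    same structure as `params` (this sketch does not validate that).
    """
    if isinstance(freeze_mask, (bool, np.bool_, np.ndarray, jax.Array)) and np.ndim(freeze_mask) == 0:
        # Scalar mask: broadcast one label to every parameter leaf.
        label = 'freeze' if bool(freeze_mask) else 'train'
        return jax.tree_util.tree_map(lambda _: label, params)
    # PyTree mask: label each leaf individually (bool() needs concrete values).
    return jax.tree_util.tree_map(
        lambda frozen: 'freeze' if bool(frozen) else 'train', freeze_mask
    )


params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
print(freeze_labels({'a': True, 'b': False}, params))  # {'a': 'freeze', 'b': 'train'}
print(freeze_labels(True, params))                     # {'a': 'freeze', 'b': 'freeze'}
```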

Returns:

  A GradientTransformation that routes each parameter leaf through:

    • the given optimizer if its mask is False ("train"),

    • set_to_zero() if its mask is True ("freeze").

Return type:

  optax.GradientTransformation

See also

optax.freeze() : For simply zeroing out gradients according to a mask.