optax.freeze
- optax.freeze(mask: bool | optax.ArrayTree) → optax.GradientTransformation
Create a transformation that zeros out gradient updates wherever mask is True.
This effectively freezes (i.e., holds constant) the masked parameters.
The mask must be static (i.e., not dependent on runtime values or updated during training) and can be:
a single boolean (or 0-d JAX bool array), which either freezes every parameter (True) or leaves every parameter trainable (False), or
a PyTree of booleans matching the structure of the parameters, where each leaf indicates whether the corresponding parameter leaf is frozen (True) or left trainable (False).
- Parameters:
mask – A boolean, or a boolean prefix-tree mask, indicating which parameters to freeze.
Example
>>> import jax.numpy as jnp
>>> from optax import freeze
>>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
>>> mask = {'a': True, 'b': False}  # Freeze 'a', train 'b'
>>> freezer = freeze(mask)
- Returns:
An Optax GradientTransformation which applies set_to_zero() wherever mask==True, and leaves other gradients intact.
See also
optax.selective_transform(): For partitioning updates so that only unfrozen parameters are optimized.