optax.freeze


optax.freeze(mask: bool | optax.ArrayTree) → optax.GradientTransformation

Create a transformation that zeros out gradient updates for mask=True.

This effectively freezes the masked parameters, i.e. holds them constant during optimization.

The mask must be static (i.e., not dependent on runtime values or updated during training) and can be:

  • a single boolean (or 0-d JAX bool array), causing every parameter to be either all-frozen (True) or all-trainable (False), or

  • a PyTree of booleans matching the structure of the parameters, where each leaf indicates whether that specific parameter leaf should be frozen (True) or have its updates passed through unchanged (False).

Parameters:

mask – A boolean prefix tree mask indicating which parameters to freeze.

Example

>>> import jax.numpy as jnp
>>> from optax import freeze
>>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
>>> mask = {'a': True, 'b': False} # Freeze 'a', train 'b'
>>> freezer = freeze(mask)

Returns:

An Optax GradientTransformation that applies set_to_zero() to the updates wherever mask==True and leaves all other updates intact.

See also

optax.selective_transform(): for partitioning updates so that only unfrozen parameters are optimized.