optax.masked

optax.masked(inner: base.GradientTransformation, mask: base.PyTree | Callable[[base.Params], base.PyTree], *, mask_compatible_extra_args: bool = False) -> base.GradientTransformationExtraArgs

Mask updates so that only some are transformed; the rest are passed through unchanged.

For example, it is common to skip weight decay for BatchNorm scale parameters and all bias parameters. In many networks these are the only 1D parameters, so a mask function that selects every parameter with ndim != 1 masks them out:

mask_fn = lambda p: jax.tree.map(lambda x: x.ndim != 1, p)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)

You may alternatively create the mask pytree upfront:

mask = jax.tree.map(lambda x: x.ndim != 1, params)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)

For the inner transform, state will only be stored for the parameters that have a mask value of True.

Note that when using tree_map_params, you may need to pass the argument is_leaf=lambda v: isinstance(v, optax.MaskedNode) if the tree map takes additional arguments with the same shape as the original input tree.

Parameters:
  • inner – Inner transformation to mask.

  • mask – A PyTree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. The mask must be static for the gradient transformation to be jit-compilable.

  • mask_compatible_extra_args – Whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.

Returns:

New optax.GradientTransformationExtraArgs wrapping inner.