optax.conditionally_mask
- optax.conditionally_mask(inner: base.GradientTransformation, should_transform_fn: ConditionFn, forward_extra_args: bool = False) → base.GradientTransformationExtraArgs
Calls the inner update function only at certain steps.
Creates a transformation wrapper that conditionally applies the inner gradient transformation. If the condition is not met, the updates are set to 0, while the inner state is passed through unchanged. The behavior is controlled by a user-specified function should_transform_fn, which conditionally_mask calls with a counter of the number of times the update function has previously been called. The function must return a boolean that controls whether the inner transformation is called.
- Parameters:
inner – the inner transformation.
should_transform_fn – a function that takes a step counter (an array of shape [] and dtype int32) and returns a boolean array of shape []. If forward_extra_args is set to True, any extra arguments passed to update are also forwarded to should_transform_fn.
forward_extra_args – whether to forward extra arguments to should_transform_fn.
- Returns:
A new GradientTransformationExtraArgs.
Warning
If instead you want to leave updates unchanged when the condition is not met, you can use the conditionally_transform wrapper.
Added in version 0.2.3.