optax.conditionally_mask


optax.conditionally_mask(inner: base.GradientTransformation, should_transform_fn: ConditionFn, forward_extra_args: bool = False) → base.GradientTransformationExtraArgs

Calls the inner update function only at certain steps.

Creates a transformation wrapper that conditionally applies the inner gradient transformation. If the condition is not met, the updates are set to 0 and the inner state is passed through unchanged. The behavior is controlled by a user-specified function, should_transform_fn. The wrapper calls this function with a counter of the number of times the update function has previously been called, and the function must return a boolean indicating whether the inner transformation should be applied.

Parameters:
  • inner – the inner transformation.

  • should_transform_fn – a function that takes a step counter (an array of shape [] and dtype int32) and returns a boolean array of shape []. If forward_extra_args is True, any extra arguments passed to update are also forwarded to should_transform_fn.

  • forward_extra_args – whether to forward extra arguments to should_transform_fn.

Returns:

A new optax.GradientTransformationExtraArgs.

Warning

If instead you want to leave updates unchanged when the condition is not met, you can use the conditionally_transform wrapper.

Added in version 0.2.3.