optax.conditionally_transform


optax.conditionally_transform(inner: base.GradientTransformation, should_transform_fn: ConditionFn, forward_extra_args: bool = False) → base.GradientTransformationExtraArgs

Calls the inner update function only at certain steps.

Creates a transformation wrapper that conditionally applies the inner gradient transformation. If the condition is not met, the wrapper passes the updates and the inner state through unchanged. The behavior is controlled by a user-specified function, should_transform_fn, which conditionally_transform calls with a counter of the number of times the update function has previously been called. should_transform_fn must return a boolean that controls whether the inner transformation is applied. The counter is incremented on every call to update, whether or not the inner transformation was applied.

Parameters:
  • inner – the inner transformation.

  • should_transform_fn – a function that takes a step counter (an array of shape [] and dtype int32) and returns a boolean array of shape []. If forward_extra_args is True, any extra arguments passed to update are also forwarded to should_transform_fn.

  • forward_extra_args – whether to forward extra arguments to should_transform_fn.

Returns:

A new optax.GradientTransformationExtraArgs.

Warning

If instead you want to set the updates to zero when the condition is not met, you can use the conditionally_mask wrapper.

Added in version 0.2.3.