optax.conditionally_transform
- optax.conditionally_transform(inner: base.GradientTransformation, should_transform_fn: ConditionFn, forward_extra_args: bool = False) → base.GradientTransformationExtraArgs
Calls the inner update function only at certain steps.
Creates a transformation wrapper that conditionally applies the inner gradient transformation. If the condition is not met, the updates and the inner state are passed through unchanged. The behavior is controlled by a user-specified function should_transform_fn, which conditionally_transform calls with a counter of the number of times the update function has previously been called. The user-specified function must return a boolean that controls whether the inner transformation is called.
- Parameters:
inner – the inner transformation.
should_transform_fn – function that takes in a step counter (array of shape [] and dtype int32) and returns a boolean array of shape []. If forward_extra_args is set to True, any extra arguments are also forwarded to should_transform_fn.
forward_extra_args – whether to forward extra args to should_transform_fn.
- Returns:
A new base.GradientTransformationExtraArgs.
Warning
If instead you want to set the updates to zero when the condition is not met, you can use the conditionally_mask wrapper.
Added in version 0.2.3.