optax.keep_params_nonnegative


optax.keep_params_nonnegative() → optax.GradientTransformation

Modifies the updates to keep parameters non-negative, i.e. >= 0.

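This transformation ensures that, after the update is applied, the parameters are greater than or equal to zero. When it is used in a chain of transformations, it should be the last one, so that it acts on the final updates that will actually be applied to the parameters.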

Returns:

An optax.GradientTransformation object.

Warning

The transformation expects input params to be non-negative. If a param is negative and the incoming update would leave it negative, the transformed update moves it to 0.
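To make the warning concrete, here is a plain `jax.numpy` sketch of the clipping rule as we understand it: wherever `p + u` would be negative, the update is replaced by `-p`, so the parameter lands exactly on 0. The helper `clip_to_nonnegative` is hypothetical and only illustrative; it is not part of optax, and in practice you would use `optax.keep_params_nonnegative` itself.

```python
import jax
import jax.numpy as jnp


def clip_to_nonnegative(updates, params):
    """Hypothetical helper (not part of optax): an illustrative version of the rule."""
    # Where the step would make p + u negative, replace u with -p so p + u == 0.
    return jax.tree_util.tree_map(
        lambda p, u: jnp.where(p + u < 0.0, -p, u), params, updates
    )


# The last two params start negative, which violates the transformation's expectation.
params = jnp.array([1.0, -2.0, -2.0])
updates = jnp.array([-3.0, 0.5, 5.0])

clipped = clip_to_nonnegative(updates, params)
print(params + clipped)  # [0. 0. 3.]
```

The second entry starts negative and its small positive update would leave it negative, so it is moved to 0. The third entry also starts negative, but its update is large enough to make it positive on its own, so it is left unchanged.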