optax.scale_by_sign

optax.scale_by_sign() -> optax.GradientTransformation

Compute the signs of the gradient elements.

Returns:

An optax.GradientTransformation that replaces each element of the input updates (for example, the gradient) with its sign.