optax.scale_by_sign
optax.scale_by_sign() → optax.GradientTransformation
Compute the signs of the gradient elements.
Returns:
An optax.GradientTransformation that replaces each element of the input updates (for example, gradients) with its sign: -1, 0, or +1.