optax.scale
- optax.scale(step_size: jax.typing.ArrayLike) → optax.GradientTransformation
Scale updates by a fixed scalar step_size.
- Parameters:
step_size – A scalar corresponding to a fixed scaling factor for updates.
- Returns:
An optax.GradientTransformation object.