optax.scale_gradient
- optax.scale_gradient(inputs: optax.ArrayTree, scale: jax.typing.ArrayLike) → optax.ArrayTree
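Scales gradients on the backward pass. The forward value is returned unchanged; only the gradient flowing back through it is multiplied by scale.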
- Parameters:
inputs – A nested array (a pytree of arrays).
scale – The factor by which the gradient is multiplied on the backward pass.
- Returns:
A pytree with the same structure and values as inputs, whose gradient is multiplied by scale on the backward pass.