optax.scale_gradient



optax.scale_gradient(inputs: optax.ArrayTree, scale: jax.typing.ArrayLike) → optax.ArrayTree

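Scales gradients on the backward pass while leaving the forward pass unchanged. The forward pass returns inputs as-is. When the surrounding computation is differentiated, the gradient flowing back through inputs is multiplied by scale.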
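The behaviour described above can be sketched with jax.custom_vjp. The forward rule is the identity, and the backward rule multiplies every leaf of the incoming gradient by scale. This is an illustrative sketch only and is not optax's actual implementation. The function name scale_grad_sketch is hypothetical. The sketch also assumes scale is a plain Python float, which is why it is declared non-differentiable via nondiff_argnums.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical sketch: identity forward, scaled gradient backward.
# Assumes `scale` is a static Python float (hence nondiff_argnums).
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale_grad_sketch(inputs, scale):
    return inputs  # forward pass: values pass through unchanged


def _fwd(inputs, scale):
    # No residuals are needed for the backward pass.
    return inputs, None


def _bwd(scale, residuals, g):
    # g has the same pytree structure as `inputs`; scale every leaf.
    return (jax.tree_util.tree_map(lambda t: t * scale, g),)


scale_grad_sketch.defvjp(_fwd, _bwd)

x = jnp.array([1.0, 2.0, 3.0])
# d/dx sum(y**2) is 2 * x = [2, 4, 6]; scaling by 2.0 gives [4, 8, 12].
grad_x = jax.grad(lambda v: jnp.sum(scale_grad_sketch(v, 2.0) ** 2))(x)
```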

Parameters:
  • inputs – A nested array (a pytree of arrays).

  • scale – The scale factor applied to the gradient on the backward pass.

Returns:

A value with the same structure as inputs. It equals inputs on the forward pass, and its gradient is multiplied by scale on the backward pass.