optax.scale_by_novograd

optax.scale_by_novograd(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.25, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, mu_dtype: str | type[Any] | dtype | SupportsDType | None = None) → optax.GradientTransformation

Computes NovoGrad updates.

See optax.novograd() for more details.

Parameters:
  • b1 – A decay rate for the exponentially weighted average of grads.

  • b2 – A decay rate for the exponentially weighted average of squared grads.

  • eps – A term added to the denominator to improve numerical stability.

  • eps_root – A term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • weight_decay – A scalar weight decay rate.

  • mu_dtype – An optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Returns:

The corresponding optax.GradientTransformation.