optax.scale_by_novograd
- optax.scale_by_novograd(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.25, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, mu_dtype: str | type[Any] | dtype | SupportsDType | None = None) → optax.GradientTransformation
Computes NovoGrad updates.
See optax.novograd() for more details.
- Parameters:
b1 – A decay rate for the exponentially weighted average of grads.
b2 – A decay rate for the exponentially weighted average of squared grads.
eps – A term added to the denominator to improve numerical stability.
eps_root – A term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
weight_decay – A scalar weight decay rate.
mu_dtype – An optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
- Returns:
The corresponding optax.GradientTransformation.