optax.scale_by_amsgrad

optax.scale_by_amsgrad(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: str | type[Any] | dtype | SupportsDType | None = None, bias_correction_mu: bool = True, bias_correction_nu: bool = True) → optax.GradientTransformation

Rescale updates according to the AMSGrad algorithm.

See optax.amsgrad() for more details.

Parameters:
  • b1 – Decay rate for the exponentially weighted average of grads.

  • b2 – Decay rate for the exponentially weighted average of squared grads.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the denominator inside the square root to improve numerical stability when backpropagating gradients through the rescaling.

  • mu_dtype – Optional dtype to be used for the first-order accumulator; if None, the dtype is inferred from params and updates.

  • bias_correction_mu – Whether to apply bias correction to the first moment estimate. Set to False to match the original AMSGrad paper.

  • bias_correction_nu – Whether to apply bias correction to the second moment estimate before taking the elementwise maximum (nu_max). Set to False to match the original AMSGrad paper.

Returns:

An optax.GradientTransformation object.
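The following NumPy sketch spells out the per-step arithmetic implied by the parameter descriptions: exponential moving averages of the gradient (mu) and squared gradient (nu), optional bias correction of each, a running elementwise maximum of the second moment (nu_max), and division by its square root. It is not optax's implementation: it works on a single array instead of a pytree, ignores mu_dtype, and the names amsgrad_init and amsgrad_update are hypothetical. Taking nu_max over the (optionally bias-corrected) nu is an assumption based on the bias_correction_nu description.

```python
import numpy as np


def amsgrad_init(params):
    """Hypothetical helper: zero-initialised state (count, mu, nu, nu_max)."""
    zeros = np.zeros_like(params, dtype=float)
    return (0, zeros.copy(), zeros.copy(), zeros.copy())


def amsgrad_update(grads, state, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0,
                   bias_correction_mu=True, bias_correction_nu=True):
    """Hypothetical helper: one AMSGrad rescaling step on a single array."""
    count, mu, nu, nu_max = state
    count += 1
    # Exponential moving averages of the gradient and the squared gradient.
    mu = b1 * mu + (1 - b1) * grads
    nu = b2 * nu + (1 - b2) * grads**2
    # Optional bias correction of each moment estimate.
    mu_hat = mu / (1 - b1**count) if bias_correction_mu else mu
    nu_hat = nu / (1 - b2**count) if bias_correction_nu else nu
    # AMSGrad's defining step: the second-moment denominator never shrinks.
    nu_max = np.maximum(nu_max, nu_hat)
    updates = mu_hat / (np.sqrt(nu_max + eps_root) + eps)
    return updates, (count, mu, nu, nu_max)


# A large gradient followed by a small one.
state0 = amsgrad_init(np.zeros(2))
u1, state1 = amsgrad_update(np.array([2.0, -3.0]), state0)
u2, state2 = amsgrad_update(np.array([0.1, 0.1]), state1)

# Same first gradient without bias correction (original-paper setting).
u_paper, _ = amsgrad_update(np.array([2.0, -3.0]), state0,
                            bias_correction_mu=False, bias_correction_nu=False)
```

With bias correction, the first step is roughly sign(grads), because mu_hat and sqrt(nu_hat) both equal |grads| in magnitude. After the smaller second gradient, nu_max keeps its earlier, larger value, which is what distinguishes AMSGrad from Adam.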