optax.contrib.scale_by_acprop

optax.contrib.scale_by_acprop(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16) -> optax.GradientTransformation

Rescale updates according to ACProp (an asynchronous version of AdaBelief).

See optax.contrib.acprop() for more details.

Parameters:
  • b1 – Decay rate for the exponentially weighted average of grads.

  • b2 – Decay rate for the exponentially weighted average of variance of grads.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

Returns:

A GradientTransformation object.