optax.contrib.scale_by_adopt

optax.contrib.scale_by_adopt(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.9999, eps: jax.typing.ArrayLike = 1e-06, mu_dtype: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType | None = None, *, nesterov: bool = False, use_clipping: bool = True, clip_value_fn: typing.Callable[[jax.Array], jax.Array] = lambda x: x ** 0.25) -> optax.GradientTransformation

Rescale updates according to the ADOPT algorithm.
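
For reference, here is a sketch of the per-step recursion as formulated in the ADOPT paper. It is written elementwise with \(\beta_1\) = b1, \(\beta_2\) = b2, \(\epsilon\) = eps and \(c_t\) = clip_value_fn(t). The initialization \(v_0 = g_0 \odot g_0\), \(m_0 = 0\) is an assumption of this sketch. The learning rate is omitted because a scale_by_* transformation only rescales updates; the output at step \(t\) is \(m_t\).

\[
\begin{aligned}
v_0 &= g_0 \odot g_0, \qquad m_0 = 0, \\
\hat g_t &= \operatorname{clip}\!\left(\frac{g_t}{\max\left(\sqrt{v_{t-1}},\, \epsilon\right)},\; c_t\right), \qquad \operatorname{clip}(x, c) = \min(\max(x, -c), c), \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, \hat g_t, \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t \odot g_t, \qquad t \ge 1.
\end{aligned}
\]

Unlike Adam, \(g_t\) is normalized by \(v_{t-1}\), which does not yet contain \(g_t\), and momentum is accumulated after normalization. With use_clipping=False the clip is dropped. The paper writes the stabilizing term as \(\max(\sqrt{v_{t-1}}, \epsilon)\).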

Parameters:
  • b1 – Decay rate for the exponentially weighted average of gradients.

  • b2 – Decay rate for the exponentially weighted average of squared gradients.

  • eps – Term added to the denominator to improve numerical stability.

  • mu_dtype – Optional dtype for the first-order accumulator; if None, the dtype is inferred from params and updates.

  • nesterov – Whether to use Nesterov momentum.

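  • use_clipping – Whether to clip the normalized gradient (the gradient divided by the square root of the previous second-moment estimate) elementwise to improve stability. When enabled, the clipping threshold is given by clip_value_fn.

  • clip_value_fn – A function that maps the step count to a clipping threshold. Defaults to lambda x: x ** 0.25, i.e. \(t^{1/4}\) at step \(t\).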
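
The clipping step controlled by use_clipping and clip_value_fn can be sketched in NumPy as follows. This illustration is not optax's code. default_clip_value_fn and clipped_normalized_grad are hypothetical names, and the normalization by max(sqrt(nu_prev), eps) is assumed from the ADOPT paper.

```python
import numpy as np


def default_clip_value_fn(step):
    # Mirrors the documented default threshold, step ** 0.25.
    return step ** 0.25


def clipped_normalized_grad(grad, nu_prev, step, eps=1e-6, use_clipping=True,
                            clip_value_fn=default_clip_value_fn):
    """Hypothetical helper: normalize grad by the previous second moment, then clip."""
    # nu_prev is the estimate from the previous step, so it does not yet include grad.
    normed = grad / np.maximum(np.sqrt(nu_prev), eps)
    if use_clipping:
        c = clip_value_fn(step)
        # Elementwise clip into [-c, c].
        normed = np.clip(normed, -c, c)
    return normed


g = np.array([10.0, -0.5])
nu_prev = np.array([1.0, 1.0])
print(clipped_normalized_grad(g, nu_prev, step=1))   # threshold 1.0: values 1.0 and -0.5
print(clipped_normalized_grad(g, nu_prev, step=16))  # threshold 2.0: values 2.0 and -0.5
print(clipped_normalized_grad(g, nu_prev, step=1, use_clipping=False))  # unclipped
```

With the default schedule the threshold grows with the step count, so clipping is tightest in the first steps.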

Returns:

An optax.GradientTransformation object.
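
A GradientTransformation bundles an init function and an update function. The NumPy sketch below imitates that pair with a hypothetical Transform named tuple and a hypothetical adopt_sketch factory that takes the same arguments. It is not optax's implementation, and two details are assumptions. First, the first update call only seeds nu with the squared gradient and returns a zero update. Second, nesterov=True returns b1 * mu + (1 - b1) * normed, by analogy with Adam's Nesterov variant. The loop applies the learning rate by hand, since this transformation does not include it.

```python
from typing import Callable, NamedTuple

import numpy as np


class AdoptState(NamedTuple):
    count: int      # number of update calls so far
    mu: np.ndarray  # first moment of the (clipped) normalized gradients
    nu: np.ndarray  # second moment of the raw gradients


class Transform(NamedTuple):
    # Hypothetical stand-in for the (init, update) pair of a GradientTransformation.
    init: Callable
    update: Callable


def adopt_sketch(b1=0.9, b2=0.9999, eps=1e-6, mu_dtype=None, *, nesterov=False,
                 use_clipping=True, clip_value_fn=lambda t: t ** 0.25):
    """Hypothetical NumPy sketch of an ADOPT rescaling transform (not optax code)."""

    def init(params):
        # mu_dtype=None keeps the dtype of params.
        mu = np.zeros_like(params, dtype=mu_dtype)
        nu = np.zeros_like(params)
        return AdoptState(count=0, mu=mu, nu=nu)

    def update(grads, state, params=None):
        del params  # unused; kept to mirror an (updates, state, params) signature
        if state.count == 0:
            # Assumption: the first call only seeds nu with g_0 ** 2 and emits a zero update.
            return np.zeros_like(grads), AdoptState(1, state.mu, grads * grads)
        # Normalize by the previous nu, which does not yet include the current grads.
        normed = grads / np.maximum(np.sqrt(state.nu), eps)
        if use_clipping:
            c = clip_value_fn(state.count)
            normed = np.clip(normed, -c, c)
        mu = b1 * state.mu + (1 - b1) * normed
        nu = b2 * state.nu + (1 - b2) * grads * grads
        # Assumption: Nesterov look-ahead of the same form as Adam's Nesterov variant.
        out = b1 * mu + (1 - b1) * normed if nesterov else mu
        return out, AdoptState(state.count + 1, mu.astype(state.mu.dtype), nu)

    return Transform(init, update)


# Usage: minimize 0.5 * ||x||^2, whose gradient is x itself.
tx = adopt_sketch()
params = np.array([1.0, -2.0])
state = tx.init(params)
learning_rate = 0.05
for _ in range(200):
    updates, state = tx.update(params, state)
    params = params - learning_rate * updates  # the learning rate is applied separately
print(params)
```

With the default b2 = 0.9999, nu changes slowly, so in this sketch the early per-coordinate step sizes are set mostly by the magnitude of the first gradient.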