optax.scale_by_yogi
- optax.scale_by_yogi(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 0.001, eps_root: jax.typing.ArrayLike = 0.0, initial_accumulator_value: jax.typing.ArrayLike = 1e-06) → optax.GradientTransformation
Rescale updates according to the Yogi algorithm.
See optax.yogi() for more details.
Supports complex numbers; see https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
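Written out, one step of the rescaling can be sketched as follows. This assumes the standard Yogi second-moment rule combined with Adam-style bias correction. Here $g_t$ is the incoming update at step $t$, $u_t$ is the rescaled output, $b_1, b_2, \varepsilon, \varepsilon_{\text{root}}$ are the parameters `b1`, `b2`, `eps`, `eps_root`, and $m_0 = v_0 =$ `initial_accumulator_value`. Using $|g_t|^2$ instead of $g_t^2$ is what keeps the rule meaningful for complex gradients.

$$
\begin{aligned}
m_t &= b_1\, m_{t-1} + (1 - b_1)\, g_t \\
v_t &= v_{t-1} - (1 - b_2)\,\operatorname{sign}\!\big(v_{t-1} - |g_t|^2\big)\,|g_t|^2 \\
\hat m_t &= \frac{m_t}{1 - b_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1 - b_2^{\,t}} \\
u_t &= \frac{\hat m_t}{\sqrt{\hat v_t + \varepsilon_{\text{root}}} + \varepsilon}
\end{aligned}
$$

Adam's second moment is $v_t = v_{t-1} - (1 - b_2)(v_{t-1} - g_t^2)$. Yogi replaces the difference with its sign, so each step can change $v_t$ by at most $(1 - b_2)|g_t|^2$. As a result, $v_t$ cannot collapse abruptly when gradients suddenly become small.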
- Parameters:
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the running average of squared grads (the second moment), which Yogi updates additively rather than exponentially.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square root to improve numerical stability when backpropagating gradients through the rescaling.
initial_accumulator_value – The starting value for accumulators. Only positive values are allowed.
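The following is a minimal NumPy sketch of where each parameter enters a single step. It works on one array rather than a JAX pytree. It assumes the Yogi second-moment rule with Adam-style bias correction and keeps the second moment real-valued so that complex gradients work. The helpers `yogi_init` and `yogi_step` are hypothetical illustrations, not part of optax. The real transform operates on pytrees with `jax.numpy` and may differ in details.

```python
import numpy as np


def yogi_init(params, initial_accumulator_value=1e-6):
    """Hypothetical helper: create Yogi state (mu, nu, count) for a single array."""
    mu = np.full_like(params, initial_accumulator_value)       # first moment, same dtype as params
    nu = np.full(np.shape(params), initial_accumulator_value)  # second moment, kept real-valued
    return mu, nu, 0                                           # step count starts at zero


def yogi_step(grad, mu, nu, count, b1=0.9, b2=0.999, eps=1e-3, eps_root=0.0):
    """Hypothetical helper: one Yogi rescaling step; returns (update, mu, nu, count)."""
    grad_sq = np.abs(grad) ** 2                  # |g|^2, also valid for complex gradients
    mu = b1 * mu + (1.0 - b1) * grad             # exponential moving average of gradients
    # Yogi's additive rule: nudge nu toward |g|^2 by a step of (1 - b2) * |g|^2.
    nu = nu - (1.0 - b2) * np.sign(nu - grad_sq) * grad_sq
    count += 1
    mu_hat = mu / (1.0 - b1 ** count)            # Adam-style bias correction
    nu_hat = nu / (1.0 - b2 ** count)
    # eps_root sits inside the square root, eps outside it.
    update = mu_hat / (np.sqrt(nu_hat + eps_root) + eps)
    return update, mu, nu, count


g = np.array([2.0, -0.5])
mu, nu, count = yogi_init(g)
update, mu, nu, count = yogi_step(g, mu, nu, count)
print(update)  # on this first step each entry is close to sign(g), i.e. about [1, -1]
```

On the first step, bias correction cancels the $(1 - b_1)$ and $(1 - b_2)$ factors, so the output is roughly the gradient's sign. Later steps depend on the accumulated history.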
- Returns:
An optax.GradientTransformation object.