optax.scale_by_optimistic_gradient
- optax.scale_by_optimistic_gradient(alpha: jax.typing.ArrayLike = 1.0, beta: jax.typing.ArrayLike = 1.0) → optax.GradientTransformation
Compute generalized optimistic gradients.
See optax.optimistic_adam_v2() and optax.optimistic_gradient_descent() for more details.
- Parameters:
alpha – Coefficient for generalized optimistic gradient descent.
beta – Coefficient for negative momentum.
- Returns:
An optax.GradientTransformation object.