optax.rmsprop
- optax.rmsprop(learning_rate: base.ScalarOrSchedule, decay: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-08, initial_scale: jax.typing.ArrayLike = 0.0, eps_in_sqrt: bool = True, centered: bool = False, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False, bias_correction: bool = False) → base.GradientTransformationExtraArgs
A flexible RMSProp optimizer.
RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a running estimate of the magnitude of the gradients from previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy-to-configure RMSProp optimizer that can switch between several of these variants.
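To make the per-weight scaling concrete, here is a minimal sketch of one plain RMSProp step in pure Python. It assumes the default settings from the signature (no momentum, not centered, no bias correction, eps inside the square root, accumulator starting at 0). The helper name `rmsprop_step` is hypothetical and this is not optax's implementation.

```python
import math


def rmsprop_step(params, grads, nu, learning_rate=0.003, decay=0.9, eps=1e-8):
    """One plain RMSProp step on Python lists (illustrative sketch only)."""
    new_params, new_nu = [], []
    for p, g, v in zip(params, grads, nu):
        # Exponential moving average of the squared gradient.
        v = decay * v + (1.0 - decay) * g * g
        # Per-weight step: the gradient divided by its running RMS.
        p = p - learning_rate * g / math.sqrt(v + eps)
        new_params.append(p)
        new_nu.append(v)
    return new_params, new_nu


def f(x):
    # Simple quadratic objective.
    return sum(xi ** 2 for xi in x)


params = [1.0, 2.0, 3.0]
nu = [0.0, 0.0, 0.0]  # assumed to correspond to initial_scale=0.0
grads = [2.0 * p for p in params]  # analytic gradient of f
params, nu = rmsprop_step(params, grads, nu)
objective = f(params)
print(f"Objective function: {objective:.2E}")  # Objective function: 1.39E+01
```

On this first step every weight moves by nearly the same amount (about 0.0095), even though the gradients differ by a factor of three, because each gradient is divided by its own magnitude estimate.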
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate().
decay – Decay used to track the magnitude of previous gradients.
eps – A small numerical constant to avoid dividing by zero when rescaling.
initial_scale – Initial value of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.
eps_in_sqrt – Whether to add eps inside the square root of the denominator (True) or outside the square root (False).
centered – If True, the variance of the past gradients is used to rescale the latest gradients; if False, the second moment is used.
momentum – Decay rate used by the momentum term. When set to None, momentum is not used at all.
nesterov – Whether Nesterov momentum is used.
bias_correction – Whether to apply bias correction to the estimates of the second moments (and first moment if centered=True).
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rmsprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
Hinton, Overview of mini-batch gradient descent (www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf), 2012
Graves, Generating Sequences With Recurrent Neural Networks, 2014
Ziyin, LaProp: Separating Momentum and Adaptivity in Adam, 2021
Warning
The default behavior of optax’s RMSprop (eps_in_sqrt=True) differs from PyTorch’s implementation and could impact performance. If eps_in_sqrt=True, optax uses \(\sqrt{v + \epsilon}\) in the denominator, whereas PyTorch uses \(\sqrt{v} + \epsilon\). Using eps_in_sqrt=False in optax will match PyTorch’s behavior. See google-deepmind/optax#532 for more detail.
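The practical effect of the two placements can be seen with a small numeric sketch in plain Python. The helper `denominator` is hypothetical and only mirrors the two formulas in the warning; it does not call optax or PyTorch.

```python
import math


def denominator(v, eps=1e-8, eps_in_sqrt=True):
    """Rescaling denominator for a second-moment estimate v (illustrative)."""
    if eps_in_sqrt:
        return math.sqrt(v + eps)  # eps inside the square root
    return math.sqrt(v) + eps      # eps outside the square root


for v in (1.0, 1e-12):
    inside = denominator(v, eps_in_sqrt=True)
    outside = denominator(v, eps_in_sqrt=False)
    print(f"v={v:g}: sqrt(v + eps)={inside:.3e}, sqrt(v) + eps={outside:.3e}")
```

When v is large, the two denominators are nearly identical. When v is much smaller than eps, the first is floored near sqrt(eps) = 1e-4 while the second stays near sqrt(v), so in this example they differ by roughly two orders of magnitude, and the resulting update sizes differ accordingly.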