optax.rprop
- optax.rprop(learning_rate: jax.typing.ArrayLike, eta_minus: jax.typing.ArrayLike = 0.5, eta_plus: jax.typing.ArrayLike = 1.2, min_step_size: jax.typing.ArrayLike = 1e-06, max_step_size: jax.typing.ArrayLike = 50.0) → base.GradientTransformationExtraArgs
The Rprop optimizer.
Rprop, short for resilient backpropagation, is a first-order variant of gradient descent. It responds only to the sign of the gradient, not its magnitude: each parameter has its own step size, which is multiplied by a constant factor to grow or shrink it from one step to the next. This speeds up convergence and avoids oscillations.
- Parameters:
learning_rate – The initial step size.
eta_minus – Multiplicative factor for decreasing step size. This is applied when the gradient changes sign from one step to the next.
eta_plus – Multiplicative factor for increasing step size. This is applied when the gradient has the same sign from one step to the next.
min_step_size – Minimum allowed step size. Smaller steps will be clipped to this value.
max_step_size – Maximum allowed step size. Larger steps will be clipped to this value.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
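To make the roles of eta_minus, eta_plus, min_step_size and max_step_size concrete, the sign-based rule can be sketched in plain Python for a single scalar parameter. This is an illustration of the basic Rprop idea only. The helper rprop_step is hypothetical, is not part of optax, and does not necessarily match the exact variant optax implements (for example, in how a sign change is handled on the following step).

```python
def rprop_step(param, grad, prev_grad, step_size,
               eta_minus=0.5, eta_plus=1.2,
               min_step_size=1e-6, max_step_size=50.0):
    """One basic Rprop update for a scalar parameter (illustrative sketch)."""
    agreement = grad * prev_grad
    if agreement > 0:
        # Same sign as on the previous step: grow the step size, clipped above.
        step_size = min(step_size * eta_plus, max_step_size)
    elif agreement < 0:
        # Sign flipped: shrink the step size, clipped below.
        step_size = max(step_size * eta_minus, min_step_size)
    # Move against the sign of the gradient; its magnitude is ignored.
    sign = (grad > 0) - (grad < 0)
    return param - sign * step_size, step_size


# Minimise f(x) = x ** 2, whose gradient is 2 * x.
x, step, prev_grad = 3.0, 0.003, 0.0
for _ in range(5):
    grad = 2.0 * x
    x, step = rprop_step(x, grad, prev_grad, step)
    prev_grad = grad
```

In this sketch the gradient keeps the same sign throughout, so the step size is multiplied by eta_plus on every iteration after the first. With a small initial step size, progress is slow at first and then accelerates.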
References
Riedmiller et al. A direct adaptive method for faster backpropagation learning: the RPROP algorithm, 1993
Igel et al. Empirical evaluation of the improved Rprop learning algorithms, 2003