optax.lars
- optax.lars(learning_rate: base.ScalarOrSchedule, weight_decay: base.ScalarOrSchedule = 0.0, weight_decay_mask: MaskOrFn = True, trust_coefficient: jax.typing.ArrayLike = 0.001, eps: jax.typing.ArrayLike = 0.0, trust_ratio_mask: MaskOrFn = True, momentum: jax.typing.ArrayLike = 0.9, nesterov: bool = False) → base.GradientTransformationExtraArgs
The LARS optimizer.
LARS (Layer-wise Adaptive Rate Scaling) rescales each layer's update by a trust ratio and was introduced to help scale SGD to larger batch sizes. It later inspired the LAMB optimizer.
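Per parameter leaf (typically one layer's weight or bias tensor), the update can be sketched as below. This assumes the transformation order optax uses (weight decay, then trust-ratio rescaling, then the learning rate, then momentum) and introduces symbols not used elsewhere on this page: w is the leaf's parameters, g its gradient, m its momentum buffer, λ = weight_decay, c = trust_coefficient, ε = eps, η = learning_rate and μ = momentum. Norms are L2 norms taken over all entries of the leaf.

```latex
\begin{aligned}
\tilde{g} &= g + \lambda\, w \\
r &= \begin{cases}
  \dfrac{c\,\lVert w \rVert_2}{\lVert \tilde{g} \rVert_2 + \epsilon}
    & \text{if } \lVert w \rVert_2 > 0 \text{ and } \lVert \tilde{g} \rVert_2 > 0, \\
  1 & \text{otherwise,}
\end{cases} \\
m &\leftarrow \mu\, m - \eta\, r\, \tilde{g} \\
w &\leftarrow w + m
\end{aligned}
```

With nesterov=True the applied step is μ m − η r g̃ (using the freshly updated m) instead of m itself. A leaf masked out by weight_decay_mask behaves as if λ = 0, and a leaf masked out by trust_ratio_mask behaves as if r = 1.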
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
weight_decay – Strength of the weight decay regularization.
weight_decay_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees to which weight decay is applied, and False for those to skip.
trust_coefficient – A multiplier for the trust ratio.
eps – Optional additive constant in the trust ratio denominator.
trust_ratio_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees whose updates are rescaled by the trust ratio, and False for those to skip.
momentum – Decay rate for momentum.
nesterov – Whether to use Nesterov momentum.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lars(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
References
You et al., Large Batch Training of Convolutional Networks, 2017.