optax.lars


optax.lars(learning_rate: base.ScalarOrSchedule, weight_decay: base.ScalarOrSchedule = 0.0, weight_decay_mask: MaskOrFn = True, trust_coefficient: jax.typing.ArrayLike = 0.001, eps: jax.typing.ArrayLike = 0.0, trust_ratio_mask: MaskOrFn = True, momentum: jax.typing.ArrayLike = 0.9, nesterov: bool = False) → base.GradientTransformationExtraArgs

The LARS optimizer.

LARS (Layer-wise Adaptive Rate Scaling) is a layer-wise adaptive optimizer introduced to help scale SGD to larger batch sizes. It later inspired the LAMB optimizer.
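
To make "layer-wise adaptive" concrete, here is a minimal hand-written sketch of trust-ratio rescaling in JAX. It is not the optax source: the helper names `l2_norm` and `scale_by_layer_trust_ratio` and the two toy layers are assumptions made for illustration, and the formula (trust coefficient times parameter norm, divided by update norm plus `eps`) is a reading of the parameter descriptions on this page.

```python
import jax
import jax.numpy as jnp


def l2_norm(x):
    # Euclidean norm over all entries of one leaf ("layer").
    return jnp.sqrt(jnp.sum(jnp.square(x)))


def scale_by_layer_trust_ratio(updates, params, trust_coefficient=0.001, eps=0.0):
    """Hypothetical helper: rescale each leaf by its own trust ratio."""

    def scale_leaf(update, param):
        param_norm = l2_norm(param)
        update_norm = l2_norm(update)
        ratio = trust_coefficient * param_norm / (update_norm + eps)
        # Assumption: fall back to a ratio of 1 when either norm is zero,
        # so a zero denominator never determines the result.
        zero_norm = jnp.logical_or(param_norm == 0.0, update_norm == 0.0)
        return update * jnp.where(zero_norm, 1.0, ratio)

    return jax.tree_util.tree_map(scale_leaf, updates, params)


# Two toy "layers" whose gradients differ in scale by a factor of 1000.
toy_params = {"layer1": jnp.ones((4, 4)), "layer2": jnp.ones((4,))}
toy_grads = {"layer1": 10.0 * jnp.ones((4, 4)), "layer2": 0.01 * jnp.ones((4,))}

scaled = scale_by_layer_trust_ratio(toy_grads, toy_params)

# After rescaling, each layer's update norm is trust_coefficient * ||param||,
# regardless of how large or small the raw gradient was.
relative_sizes = jax.tree_util.tree_map(
    lambda u, p: l2_norm(u) / l2_norm(p), scaled, toy_params
)
```

In this sketch both layers end up with the same update size relative to their parameters, which is the property that lets one global learning rate serve layers with very different gradient scales.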

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • weight_decay – Strength of the weight decay regularization.

  • weight_decay_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • trust_coefficient – A multiplier for the trust ratio.

  • eps – Optional additive constant in the trust ratio denominator.

  • trust_ratio_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • momentum – Decay rate for momentum.

  • nesterov – Whether to use Nesterov momentum.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lars(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01

References

You et al., Large Batch Training of Convolutional Networks, 2017