optax.adafactor


optax.adafactor(learning_rate: base.ScalarOrSchedule | None = None, min_dim_size_to_factor: int = 128, decay_rate: jax.typing.ArrayLike = 0.8, decay_offset: jax.typing.ArrayLike = 0, multiply_by_parameter_scale: bool = True, clipping_threshold: jax.typing.ArrayLike | None = 1.0, momentum: jax.typing.ArrayLike | None = None, dtype_momentum: jax.typing.DTypeLike = jax.numpy.float32, weight_decay_rate: base.ScalarOrSchedule | None = None, eps: jax.typing.ArrayLike = 1e-30, factored: bool = True, weight_decay_mask: MaskOrFn = None) → base.GradientTransformationExtraArgs

The Adafactor optimizer.

Adafactor is an adaptive learning rate optimizer designed for fast training of large-scale neural networks. It saves memory by storing a factored estimate of the second moments of the gradients, which are used to scale the updates, instead of a full per-parameter estimate.
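To illustrate the idea of a factored estimate, here is a minimal pure-Python sketch, not the optax implementation. It assumes the row/column factorization described in the Adafactor paper: a matrix of squared gradients is summarized by its row means and column means, and an approximation is rebuilt from their outer product. The helper name `factored_second_moment` is hypothetical, and the sketch omits the exponential moving average and the `eps` regularization.

```python
def factored_second_moment(sq_grads):
    """Approximate a matrix of squared gradients from its row and column means.

    Only n_rows + n_cols numbers need to be stored instead of n_rows * n_cols.
    """
    n_rows, n_cols = len(sq_grads), len(sq_grads[0])
    # Mean of each row (one value per row).
    row_mean = [sum(row) / n_cols for row in sq_grads]
    # Mean of each column (one value per column).
    col_mean = [
        sum(sq_grads[i][j] for i in range(n_rows)) / n_rows
        for j in range(n_cols)
    ]
    # Mean over all entries, used to normalize the outer product.
    overall = sum(row_mean) / n_rows
    return [
        [row_mean[i] * col_mean[j] / overall for j in range(n_cols)]
        for i in range(n_rows)
    ]


# A rank-1 matrix of squared gradients is recovered exactly (up to rounding).
a = [1.0, 2.0, 4.0]
b = [0.5, 3.0]
sq_grads = [[ai * bj for bj in b] for ai in a]
approx = factored_second_moment(sq_grads)
```

For a general matrix the reconstruction is only an approximation; the memory saving comes from keeping the two vectors rather than the full matrix.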

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Note that the natural scale for Adafactor’s learning rate is markedly different from Adam’s: with attention-based models, the 1/sqrt(hidden) correction is not used with this optimizer.

  • min_dim_size_to_factor – Only factor the statistics if two array dimensions have at least this size.

  • decay_rate – Controls the exponential decay schedule of the second-moment estimates.

  • decay_offset – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.

  • multiply_by_parameter_scale – If True, scale learning_rate by the parameter norm. If False, the provided learning_rate is an absolute step size.

  • clipping_threshold – Optional clipping threshold. Must be >= 1. If None, clipping is disabled.

  • momentum – Optional value between 0 and 1. If not None, enables momentum, which uses extra memory. Defaults to None.

  • dtype_momentum – Data type of momentum buffers.

  • weight_decay_rate – Optional rate at which to decay weights.

  • eps – Regularization constant for root mean squared gradient.

  • factored – Whether to use factored second-moment estimates.

  • weight_decay_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
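As a rough illustration of how decay_rate and decay_offset might interact, the sketch below assumes the schedule proposed in the Adafactor paper, where the second-moment decay at step t is 1 - t^(-decay_rate). The function name `second_moment_decay`, the zero-based step count, and the way the offset is subtracted are assumptions made for this example and may not match the library's internals exactly.

```python
def second_moment_decay(step, decay_rate=0.8, decay_offset=0):
    """Decay factor for the second-moment average at a zero-based step.

    Assumes the schedule 1 - t ** (-decay_rate) with t counted from 1,
    restarted at `decay_offset`. Steps before the offset are not handled.
    """
    t = step - decay_offset + 1
    return 1.0 - t ** (-decay_rate)


# From scratch: the first step relies only on the current gradient,
# and later steps average over a progressively longer history.
from_scratch = [second_moment_decay(s) for s in (0, 1, 10, 100)]

# Fine-tuning that starts at step 1000: setting decay_offset=1000
# restarts the schedule instead of continuing from a value close to 1.
resumed = second_moment_decay(1000, decay_offset=1000)
continued = second_moment_decay(1000)
```

Under these assumptions, `resumed` is 0.0 while `continued` is close to 1, which is why an offset can be useful when a fine-tuning run begins at a large step number.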

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adafactor(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

References

Shazeer et al, Adafactor: Adaptive Learning Rates with Sublinear Memory Cost, 2018