optax.adafactor
- optax.adafactor(learning_rate: base.ScalarOrSchedule | None = None, min_dim_size_to_factor: int = 128, decay_rate: jax.typing.ArrayLike = 0.8, decay_offset: jax.typing.ArrayLike = 0, multiply_by_parameter_scale: bool = True, clipping_threshold: jax.typing.ArrayLike | None = 1.0, momentum: jax.typing.ArrayLike | None = None, dtype_momentum: jax.typing.DTypeLike = <class 'jax.numpy.float32'>, weight_decay_rate: base.ScalarOrSchedule | None = None, eps: jax.typing.ArrayLike = 1e-30, factored: bool = True, weight_decay_mask: MaskOrFn = None) -> base.GradientTransformationExtraArgs
The Adafactor optimizer.
Adafactor is an adaptive learning rate optimizer designed for fast training of large-scale neural networks. It saves memory by using a factored estimate of the second-order moments used to scale gradients.
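To illustrate what "factored estimate" means, here is a minimal NumPy sketch of the rank-1 idea described in the Adafactor paper: a non-negative matrix of squared gradients is summarized by its row sums and column sums, and an estimate is rebuilt from their outer product. The function names are hypothetical, and this is not optax's implementation, which additionally keeps exponential moving averages and applies eps regularization.

```python
import numpy as np


def factor_second_moment(grad_sq):
    """Summarize a non-negative (n, m) matrix by n row sums and m column sums."""
    row_sums = grad_sq.sum(axis=1)  # shape (n,)
    col_sums = grad_sq.sum(axis=0)  # shape (m,)
    return row_sums, col_sums


def reconstruct(row_sums, col_sums):
    """Rank-1 estimate of the matrix: outer(row_sums, col_sums) / total sum."""
    return np.outer(row_sums, col_sums) / row_sums.sum()


# A rank-1 toy gradient, so the factored estimate is exact in this case.
grad = np.outer([1.0, 2.0], [1.0, 2.0, 3.0])
grad_sq = grad ** 2

row_sums, col_sums = factor_second_moment(grad_sq)
v_hat = reconstruct(row_sums, col_sums)

# Storage drops from n * m values to n + m values.
stored_factored = row_sums.size + col_sums.size
stored_full = grad_sq.size
```

For a general matrix the reconstruction is only an approximation; the benefit is that storage grows with n + m instead of n * m.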
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate(). Note that the natural scale for Adafactor's learning rate is markedly different from Adam's: one doesn't use the 1/sqrt(hidden) correction for this optimizer with attention-based models.
min_dim_size_to_factor – Only factor the statistics if two array dimensions have at least this size.
decay_rate – Controls second-moment exponential decay schedule.
decay_offset – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.
multiply_by_parameter_scale – If True, then scale learning_rate by parameter norm. If False, provided learning_rate is absolute step size.
clipping_threshold – Optional clipping threshold. Must be >= 1. If None, clipping is disabled.
momentum – Optional value between 0 and 1, enables momentum and uses extra memory if non-None! None by default.
dtype_momentum – Data type of momentum buffers.
weight_decay_rate – Optional rate at which to decay weights.
eps – Regularization constant for root mean squared gradient.
factored – Whether to use factored second-moment estimates.
weight_decay_mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adafactor(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
Shazeer and Stern, Adafactor: Adaptive Learning Rates with Sublinear Memory Cost, 2018