optax.scale_by_factored_rms
- optax.scale_by_factored_rms(factored: bool = True, decay_rate: jax.typing.ArrayLike = 0.8, step_offset: jax.typing.ArrayLike = 0, min_dim_size_to_factor: int = 128, epsilon: jax.typing.ArrayLike = 1e-30, decay_rate_fn: Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike], jax.typing.ArrayLike] = _decay_rate_pow)
Scales updates by a factored estimate of the gradient root mean square (RMS), as in Adafactor.
This is a so-called “1+epsilon” scaling algorithm. It is extremely memory efficient compared to RMSProp/Adam: for a matrix parameter of shape (n, m), it can keep n + m second-moment statistics instead of n·m. It has been applied with wide success to large-scale training of attention-based models.
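To make the memory argument concrete, here is a minimal NumPy sketch (not optax code) of the factored statistics for a single 2-D array. It assumes the Adafactor-style rank-1 reconstruction, in which the full second moment is estimated as the outer product of the row and column means divided by their overall mean; `factored_moments` and `reconstruct` are hypothetical helper names used only for illustration.

```python
import numpy as np


def factored_moments(sq_grad):
    # Keep one mean per row and one per column instead of the whole array.
    return sq_grad.mean(axis=1), sq_grad.mean(axis=0)


def reconstruct(row, col):
    # Rank-1 estimate of the full array: outer(row, col) / overall mean.
    return np.outer(row, col) / row.mean()


rng = np.random.default_rng(0)
a = rng.uniform(0.1, 1.0, size=4)
b = rng.uniform(0.1, 1.0, size=6)
sq_grad = np.outer(a, b)  # toy non-negative, rank-1 "squared gradient"

row, col = factored_moments(sq_grad)
print(row.size + col.size, "statistics stored instead of", sq_grad.size)
print(np.allclose(reconstruct(row, col), sq_grad))
```

The reconstruction is exact here only because the toy array is rank-1; for a general squared gradient it is an approximation, which is the price paid for storing n + m numbers instead of n·m.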
- Parameters:
factored – boolean: whether to use factored second-moment estimates.
decay_rate – float: controls the second-moment exponential decay schedule.
step_offset – for fine-tuning, one may set this to the starting step number of the fine-tuning phase.
min_dim_size_to_factor – only factor the accumulator if at least two array dimensions have at least this size.
epsilon – Regularization constant added to the squared gradient.
decay_rate_fn – A function that accepts the current step and the decay rate parameter and returns the decay applied to the second-moment estimate, thereby controlling its schedule. Defaults to the original Adafactor power decay schedule. One potential shortcoming of the original schedule is that this decay converges to 1, which effectively freezes the second-moment estimate. To prevent this, the user can opt for a custom schedule that sets an upper bound on the decay, as in Zhai et al., 2021.
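The default schedule and a capped alternative of the kind described above can be compared numerically. This is a minimal NumPy sketch, assuming the default follows Adafactor's power form 1 - (step + 1)^(-decay_rate); `capped_decay_rate` and its `max_decay=0.999` cap are hypothetical, illustrative choices, not values taken from optax or from Zhai et al.

```python
import numpy as np


def decay_rate_pow(step, decay_rate=0.8):
    # Power schedule (assumed default form): 1 - (step + 1) ** (-decay_rate).
    t = np.asarray(step, dtype=np.float64) + 1.0
    return 1.0 - t ** (-decay_rate)


def capped_decay_rate(step, decay_rate=0.8, max_decay=0.999):
    # Hypothetical custom schedule: same power law, clipped at max_decay so that
    # new gradients always keep a weight of at least 1 - max_decay.
    return np.minimum(decay_rate_pow(step, decay_rate), max_decay)


steps = np.array([0, 9, 999, 999_999])
print(decay_rate_pow(steps))     # keeps approaching 1
print(capped_decay_rate(steps))  # never exceeds 0.999
```

Both functions take `(step, decay_rate)` to match the documented signature of `decay_rate_fn`. Because optax calls this function inside the transformation's update, a version actually passed to `scale_by_factored_rms` would likely need to be written with `jax.numpy` so that it works on traced arrays.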
- Returns:
The corresponding optax.GradientTransformation.
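Conceptually, the returned transformation pairs an `init` function that allocates the accumulators with an `update` function that rescales each gradient. The sketch below mimics that for a single NumPy array rather than a pytree; `factored_rms_init` and `factored_rms_update` are hypothetical stand-ins, not optax functions. It assumes the factoring rule described for `min_dim_size_to_factor`, the power decay schedule, and `epsilon` added to the squared gradient, and for simplicity it only factors 2-D arrays (everything else takes the unfactored path here).

```python
import numpy as np


def decay_rate_pow(step, exponent=0.8):
    # Adafactor-style power schedule: 0 at step 0, rising towards 1.
    return 1.0 - (step + 1.0) ** (-exponent)


def should_factor(shape, factored=True, min_dim_size_to_factor=128):
    # Sketch only handles 2-D arrays: factor when both dimensions are large enough.
    return factored and len(shape) == 2 and min(shape) >= min_dim_size_to_factor


def factored_rms_init(param, factored=True, min_dim_size_to_factor=128):
    """Allocate second-moment accumulators for one array (hypothetical helper)."""
    if should_factor(param.shape, factored, min_dim_size_to_factor):
        n, m = param.shape
        # Factored: one statistic per row and one per column (n + m numbers).
        return {"count": 0, "v_row": np.zeros(n), "v_col": np.zeros(m)}
    # Unfactored: a full accumulator with the same shape as the parameter.
    return {"count": 0, "v": np.zeros_like(param, dtype=float)}


def factored_rms_update(grad, state, decay_rate=0.8, epsilon=1e-30):
    """Return (scaled_update, new_state) for one gradient array (hypothetical helper)."""
    step = state["count"]
    beta = decay_rate_pow(step, decay_rate)
    grad_sq = grad ** 2 + epsilon  # epsilon regularizes the squared gradient
    if "v_row" in state:
        # Exponential moving averages of the row and column means.
        v_row = beta * state["v_row"] + (1 - beta) * grad_sq.mean(axis=1)
        v_col = beta * state["v_col"] + (1 - beta) * grad_sq.mean(axis=0)
        # Rank-1 reconstruction of the full second-moment estimate.
        v_hat = np.outer(v_row, v_col) / v_row.mean()
        new_state = {"count": step + 1, "v_row": v_row, "v_col": v_col}
    else:
        v_hat = beta * state["v"] + (1 - beta) * grad_sq
        new_state = {"count": step + 1, "v": v_hat}
    # Divide by the estimated RMS of the gradient.
    return grad / np.sqrt(v_hat), new_state


# Example: a 256 x 512 weight matrix is factored with the default threshold.
a = np.linspace(0.1, 1.0, 256)
b = np.linspace(0.2, 2.0, 512)
grad = np.outer(a, b)  # rank-1 toy gradient
state = factored_rms_init(np.zeros((256, 512)))
scaled, state = factored_rms_update(grad, state)
print(state["v_row"].shape, state["v_col"].shape, state["count"])
print(np.allclose(scaled, 1.0))
```

At step 0 the decay is 0, so the estimate equals the current squared gradient; for the rank-1 gradient in the example, the scaled update is therefore (up to `epsilon`) the sign of the gradient.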
References
Shazeer and Stern, Adafactor: Adaptive Learning Rates with Sublinear Memory Cost, 2018
Zhai et al., Scaling Vision Transformers, 2021