optax.contrib.mechanize
- optax.contrib.mechanize(base_optimizer: base.GradientTransformation | base.GradientTransformationExtraArgs, weight_decay: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-08, s_init: jax.typing.ArrayLike = 1e-06, num_betas: int = 6) → base.GradientTransformationExtraArgs
Mechanic: a black-box learning rate tuner/optimizer.
Accumulates the updates returned by base_optimizer and learns, on a per-iteration basis, the scale of the updates to apply (also known as the learning rate or step size).
Note that Mechanic does NOT remove the need for a learning rate schedule: you are free to apply one with the base learning rate set to 1.0 (or any other constant), and Mechanic will learn the right scale factor automatically.
For example, change this:

    learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
    optimizer = optax.adam(learning_rate_fn)

To:

    learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0)
    optimizer = optax.adam(learning_rate_fn)
    optimizer = optax.contrib.mechanize(optimizer)

As of June 2023, Mechanic has been tested with SGD, Momentum, Adam and Lion as inner optimizers, but we expect it to work with almost any first-order optimizer (except for normalized-gradient optimizers such as LARS or LAMB).
- Parameters:
base_optimizer – Base optimizer to compute updates from.
weight_decay – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping convergence, it helps Mechanic reduce the variance between training runs that use different seeds. You will likely not need to tune this; the default should work in most cases.
eps – Epsilon (small constant) used by Mechanic for numerical stability.
s_init – Initial scale factor. The default should work almost all the time.
num_betas – Unlike traditional exponential accumulators (such as Adam's first and second moments), where one has to choose an explicit beta, Mechanic automatically learns the right beta for all accumulators. You only specify the number of candidate betas, not a tuned value. For instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, 0.999].
- Returns: The base optimizer wrapped by Mechanic, as a base.GradientTransformationExtraArgs.
References
Cutkosky et al., Mechanic: A Learning Rate Tuner, 2023.