🔧 Contrib


Algorithms or wrappers that do not (yet) meet the Inclusion Criteria or are not supported by the main library.

acprop(learning_rate[, b1, b2, eps, ...])

The ACProp optimizer.

ademamix(learning_rate[, b1, b2, b3, alpha, ...])

AdEMAMix.
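AdEMAMix extends Adam with a second, much slower EMA of the gradients (decay b3) that is mixed into the numerator with weight alpha. Below is a minimal NumPy sketch of one step on a flat vector. The hyperparameter values are illustrative, and the alpha/b3 warmup schedulers and weight decay are left out. As in the AdEMAMix paper, the fast EMA and the second moment are bias-corrected and the slow EMA is not. `ademamix_step` is a hypothetical helper, not the contrib signature.

```python
import numpy as np

def ademamix_step(params, grad, state, lr=1e-3, b1=0.9, b2=0.999, b3=0.9999,
                  alpha=5.0, eps=1e-8):
    """One AdEMAMix step (sketch: no alpha/b3 schedulers, no weight decay)."""
    m1, m2, nu, t = state
    t += 1
    m1 = b1 * m1 + (1 - b1) * grad        # fast EMA, as in Adam
    m2 = b3 * m2 + (1 - b3) * grad        # slow EMA, not bias-corrected
    nu = b2 * nu + (1 - b2) * grad**2     # second moment, as in Adam
    m1_hat = m1 / (1 - b1**t)
    nu_hat = nu / (1 - b2**t)
    update = (m1_hat + alpha * m2) / (np.sqrt(nu_hat) + eps)
    return params - lr * update, (m1, m2, nu, t)
```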

adopt(learning_rate, b1, b2, eps, mu_dtype, ...)

ADOPT (Adaptive Optimization with Provable Theoretical guarantees).
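ADOPT changes Adam in two places. The gradient is normalized by the second-moment estimate from the *previous* step before it enters the momentum, and the very first gradient is used only to initialize that second moment. Below is a minimal NumPy sketch of this ordering with illustrative defaults. It omits the optional clipping that some implementations add. `adopt_init` and `adopt_step` are hypothetical helpers.

```python
import numpy as np

def adopt_init(first_grad):
    # v_0 = g_0^2: the first gradient only initializes the second moment.
    return np.zeros_like(first_grad), first_grad**2

def adopt_step(params, grad, state, lr=1e-3, b1=0.9, b2=0.9999, eps=1e-6):
    m, v = state
    # Normalize by the previous second moment, so the current gradient
    # does not enter its own normalizer.
    m = b1 * m + (1 - b1) * grad / np.maximum(np.sqrt(v), eps)
    params = params - lr * m
    v = b2 * v + (1 - b2) * grad**2   # second moment is updated after the step
    return params, (m, v)
```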

simplified_ademamix(learning_rate[, b1, b2, ...])

Simplified AdEMAMix.

cocob([learning_rate, alpha, eps, ...])

Rescale updates according to the COntinuous COin Betting algorithm.

COCOBState(init_particles, ...)

State for COntinuous COin Betting.

dadapt_adamw([learning_rate, betas, eps, ...])

Learning-rate-free AdamW by D-Adaptation.

DAdaptAdamWState(exp_avg, exp_avg_sq, ...)

State of the GradientTransformation returned by dadapt_adamw.

differentially_private_aggregate(...[, key, ...])

Aggregates gradients based on the DPSGD algorithm.

DifferentiallyPrivateAggregateState(rng_key)

State containing PRNGKey for differentially_private_aggregate.
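The standard DP-SGD aggregation works as follows. Each per-example gradient is clipped to an L2 norm of at most l2_norm_clip. The clipped gradients are summed, Gaussian noise with standard deviation noise_multiplier * l2_norm_clip is added, and the result is divided by the batch size. Below is a minimal NumPy sketch of that recipe for flat per-example gradients. The `noise_multiplier` argument and the exact normalization are assumptions based on the usual DP-SGD formulation, not a reproduction of the contrib function.

```python
import numpy as np

def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """per_example_grads: array of shape (batch, dim)."""
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    # Scale each example's gradient so its L2 norm is at most l2_norm_clip.
    scale = np.minimum(1.0, l2_norm_clip / np.maximum(norms, 1e-12))
    clipped_sum = (per_example_grads * scale).sum(axis=0)
    # Gaussian noise calibrated to the clipping norm (per-example sensitivity).
    noise = rng.normal(0.0, noise_multiplier * l2_norm_clip, size=clipped_sum.shape)
    return (clipped_sum + noise) / per_example_grads.shape[0]
```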

dog([learning_rate, init_step, eps, ...])

Distance over Gradients (DoG) optimizer.

DoGState(is_init_step, init_params, ...)

State for the DoG optimizer.

dowg([learning_rate, init_estim_sq_dist, ...])

Distance over Weighted Gradients (DoWG) optimizer.

DoWGState(init_params, ...)

State for the DoWG optimizer.
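DoG and DoWG are parameter-free. Let r̄ be the largest distance travelled from the initial point so far. DoG sets the step size to r̄ divided by the square root of the accumulated squared gradient norms. DoWG uses r̄² divided by the square root of a sum of squared gradient norms weighted by r̄². Below is a minimal NumPy sketch covering both on a flat vector. The initial value r̄ = r_eps * (1 + ||x0||) and the tiny denominator guard are assumptions. `run_dog` and its arguments are hypothetical, not the dog or dowg signatures.

```python
import numpy as np

def run_dog(grad_fn, x0, steps=500, r_eps=1e-6, weighted=False):
    """DoG (weighted=False) or DoWG (weighted=True) on a flat parameter vector."""
    x = x0.copy()
    r_bar = r_eps * (1.0 + np.linalg.norm(x0))    # initial distance estimate
    v = 0.0                                        # accumulated (weighted) grad norms
    for _ in range(steps):
        g = grad_fn(x)
        r_bar = max(r_bar, np.linalg.norm(x - x0))  # max distance from x0 so far
        if weighted:
            v += r_bar**2 * np.dot(g, g)          # DoWG: weight each term by r_bar^2
            eta = r_bar**2 / (np.sqrt(v) + 1e-12)
        else:
            v += np.dot(g, g)                     # DoG: plain sum of squared norms
            eta = r_bar / (np.sqrt(v) + 1e-12)
        x = x - eta * g
    return x
```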

dpsgd(learning_rate, l2_norm_clip, ...[, ...])

The DPSGD optimizer.

galore(learning_rate[, rank, ...])

GaLore: Memory-efficient training via gradient low-rank projection.

GaLoreState(count, base_optimizer_state, ...)

State for the GaLore optimizer.
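GaLore saves optimizer memory by projecting each (m, n) weight gradient onto a rank-r subspace. It keeps the optimizer state only in that subspace and projects the resulting step back to full size. The projector is recomputed from an SVD of the gradient every so many steps. Below is a minimal NumPy sketch of this idea. To stay short, it swaps the Adam-style inner optimizer used in the GaLore paper for plain momentum, always projects on the left, and keeps the momentum when the projector is refreshed. The class and its `update_proj_gap` parameter are hypothetical, not the galore signature.

```python
import numpy as np

class GaLoreMomentumSketch:
    """Momentum SGD whose state lives in a rank-r subspace of an (m, n) gradient."""

    def __init__(self, rank, lr=1e-2, beta=0.9, update_proj_gap=200):
        self.rank, self.lr, self.beta, self.gap = rank, lr, beta, update_proj_gap
        self.proj, self.mom, self.step = None, None, 0

    def update(self, param, grad):
        if self.step % self.gap == 0:
            # Refresh projector from the top-r left singular vectors of the gradient.
            u, _, _ = np.linalg.svd(grad, full_matrices=False)
            self.proj = u[:, : self.rank]                    # (m, r)
            if self.mom is None:
                self.mom = np.zeros((self.rank, grad.shape[1]))
        low = self.proj.T @ grad                              # (r, n) projected gradient
        self.mom = self.beta * self.mom + low                 # state is only (r, n)
        self.step += 1
        return param - self.lr * (self.proj @ self.mom)       # project back to (m, n)
```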

madgrad(learning_rate[, momentum, ...])

The MADGRAD optimizer.

MadgradState(count, grad_sum_sq, s, x0)

State for the MADGRAD optimizer.
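MADGRAD is a momentumized dual-averaging method. It accumulates weighted sums of gradients and squared gradients (compare `s` and `grad_sum_sq` in MadgradState). From these and the initial parameters x0 it forms a point z using a cube-root denominator, then moves x toward z by averaging. Below is a minimal NumPy sketch of that update for one flat vector. It assumes the step weight lr * sqrt(k + 1) from the MADGRAD paper, omits weight decay, and is not the contrib code.

```python
import numpy as np

def madgrad_step(x, grad, state, lr=1e-2, momentum=0.9, eps=1e-6):
    x0, s, nu, k = state
    lam = lr * np.sqrt(k + 1)               # step weights grow like sqrt(k + 1)
    s = s + lam * grad                      # weighted gradient sum (dual average)
    nu = nu + lam * grad**2                 # weighted squared-gradient sum
    z = x0 - s / (np.cbrt(nu) + eps)        # cube root, not AdaGrad's square root
    c = 1.0 - momentum
    x = (1 - c) * x + c * z                 # momentum via averaging toward z
    return x, (x0, s, nu, k + 1)
```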

mechanize(base_optimizer[, weight_decay, ...])

Mechanic: a black-box learning rate tuner/optimizer.

MechanicState(base_optimizer_state, count, ...)

State of the GradientTransformation returned by mechanize.

momo([learning_rate, beta, lower_bound, ...])

Adaptive Learning Rates for SGD with momentum.

MomoState(exp_avg, barf, gamma, lb, count)

State of the GradientTransformation returned by momo.
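Momo chooses each step size from a model built from momentum averages of the gradient, the loss, and ⟨g, x⟩, capped by the learning rate. This is why it needs the loss value and not only the gradient. Below is a minimal NumPy sketch of that update on a flat vector, following the Momo paper as we understand it. Seeding the averages with the first observation and the small denominator guard are assumptions. The state tuple only loosely mirrors MomoState(exp_avg, barf, gamma, lb, count), and `momo_step` is hypothetical.

```python
import numpy as np

def momo_step(x, loss, grad, state, lr=1.0, beta=0.9, lower_bound=0.0):
    d, f_bar, gamma, k = state
    rho = 0.0 if k == 0 else beta                       # first step: seed averages
    d = rho * d + (1 - rho) * grad                      # momentum of gradients
    f_bar = rho * f_bar + (1 - rho) * loss              # momentum of losses
    gamma = rho * gamma + (1 - rho) * np.dot(grad, x)   # momentum of <g, x>
    h = f_bar + np.dot(d, x) - gamma                    # model value at x
    # Polyak-type step toward the lower bound, capped by the learning rate.
    tau = min(lr, max(h - lower_bound, 0.0) / (np.dot(d, d) + 1e-12))
    return x - tau * d, (d, f_bar, gamma, k + 1)
```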

momo_adam([learning_rate, b1, b2, eps, ...])

Adaptive Learning Rates for Adam(W).

MomoAdamState(exp_avg, exp_avg_sq, barf, ...)

State of the GradientTransformation returned by momo_adam.

muon(learning_rate[, ns_coeffs, ns_steps, ...])

Muon: Momentum Orthogonalized by Newton-Schulz.

MuonState(count, mu, ns_coeffs)

State for the Muon algorithm.
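The `ns_coeffs` and `ns_steps` parameters of muon refer to a Newton-Schulz iteration. For a 2-D parameter, it maps the momentum G = U S V^T to approximately U V^T. Below is a minimal NumPy sketch. The coefficients are the quintic ones from the original Muon write-up. It uses plain momentum and no shape-dependent scaling, so whether the contrib defaults (Nesterov momentum, extra scaling) match this sketch should be checked. `newton_schulz` and `muon_step` are hypothetical helpers.

```python
import numpy as np

def newton_schulz(g, coeffs=(3.4445, -4.7750, 2.0315), steps=5, eps=1e-7):
    """Approximately replace G = U S V^T by U V^T (singular values pushed near 1)."""
    a, b, c = coeffs
    x = g / (np.linalg.norm(g) + eps)    # Frobenius scaling keeps singular values <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T                           # use the wide orientation (smaller x @ x.T)
    for _ in range(steps):
        gram = x @ x.T
        # Odd quintic in the singular values: a*s + b*s^3 + c*s^5.
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x

def muon_step(w, grad, mu, lr=0.02, beta=0.95):
    mu = beta * mu + grad                    # plain (non-Nesterov) momentum buffer
    return w - lr * newton_schulz(mu), mu    # step along the orthogonalized momentum
```

Because the coefficients favor speed over exact convergence, the output's singular values oscillate around 1 rather than equal it.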

prodigy([learning_rate, betas, beta3, eps, ...])

Learning-rate-free AdamW with Prodigy.

ProdigyState(exp_avg, exp_avg_sq, grad_sum, ...)

State of the GradientTransformation returned by prodigy.

sam(optimizer, adv_optimizer[, sync_period, ...])

Implementation of SAM (Sharpness-Aware Minimization).

SAMState(steps_since_sync, opt_state, ...)

State of the GradientTransformation returned by sam.
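SAM takes an ascent step of radius rho along the normalized gradient and recomputes the gradient at that perturbed point. It then applies that gradient at the original weights. Below is a minimal NumPy sketch of this classic two-gradient step. The sam signature separates `optimizer` from `adv_optimizer`, which suggests the ascent step is configurable there. Treat `rho`, the plain SGD outer step, and the `sam_step` helper as assumptions.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    g = grad_fn(w)
    # Ascent: first-order worst-case point within an L2 ball of radius rho.
    w_adv = w + rho * g / (np.linalg.norm(g) + 1e-12)
    # Descent at the original weights, using the gradient from the perturbed point.
    return w - lr * grad_fn(w_adv)
```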

schedule_free(base_optimizer, learning_rate)

Turn base_optimizer into a schedule-free optimizer.

schedule_free_adamw([learning_rate, ...])

Schedule-Free wrapper for AdamW.

schedule_free_eval_params(state, params)

Parameters for evaluating optax.contrib.schedule_free().

schedule_free_sgd([learning_rate, ...])

Schedule-Free wrapper for SGD.

ScheduleFreeState(b1, weight_sum, ...)

State for schedule_free.
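Schedule-Free methods replace a decaying learning-rate schedule with iterate averaging. Gradients are taken at an interpolation y between the base SGD iterate z and the running average x, and x is what gets evaluated. This is why the list has a separate schedule_free_eval_params(state, params). Below is a minimal NumPy sketch of Schedule-Free SGD under simplifying assumptions: uniform averaging weights, no warmup, and a fixed learning rate. It is not the contrib implementation, and `schedule_free_sgd` is a hypothetical helper.

```python
import numpy as np

def schedule_free_sgd(grad_fn, x0, lr=0.1, beta=0.9, steps=200):
    z = x0.copy()   # base SGD iterate
    x = x0.copy()   # running average of z: the point to evaluate
    for t in range(steps):
        y = (1 - beta) * z + beta * x     # gradients are taken at the interpolation y
        z = z - lr * grad_fn(y)
        c = 1.0 / (t + 2)                 # uniform averaging weight for the new z
        x = (1 - c) * x + c * z
    return x, z
```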

sophia(learning_rate, b1, b2, eps, ...)

Sophia optimizer.

SophiaState(count, mu, nu, hessian_fn_state)

State for the Sophia optimizer.
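Sophia preconditions an EMA of the gradients by an EMA of a diagonal-Hessian estimate, which is refreshed only every few steps. It then clips each coordinate, so flat or negative-curvature directions cannot produce huge steps. Below is a minimal NumPy sketch of one step. The hyperparameter values are illustrative, and the clip threshold of 1 and the eps guard are assumptions. The Hessian-diagonal estimate is passed in (for example from a Hutchinson estimator) rather than computed, and `sophia_step` is hypothetical.

```python
import numpy as np

def sophia_step(params, grad, hess_diag_est, state, lr=1e-3, b1=0.965, b2=0.99,
                gamma=0.01, eps=1e-12, update_hessian=True):
    m, h = state
    m = b1 * m + (1 - b1) * grad                     # EMA of gradients
    if update_hessian:                               # refreshed only every k steps
        h = b2 * h + (1 - b2) * hess_diag_est        # EMA of the Hessian-diagonal estimate
    ratio = m / np.maximum(gamma * h, eps)           # Newton-like preconditioned step
    return params - lr * np.clip(ratio, -1.0, 1.0), (m, h)
```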

split_real_and_imaginary(inner)

Splits complex updates into their real and imaginary components.

SplitRealAndImaginaryState(inner_state)

Maintains the inner transformation state for split_real_and_imaginary.
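Some transformations, such as sign-based updates or per-coordinate clipping, are defined only for real numbers. Splitting a complex update x + iy into x and y, transforming each, and recombining gives them a well-defined meaning. Below is a minimal NumPy sketch with a stateless inner function. The contrib wrapper also threads the inner transformation's state, which this sketch omits, and the function here is a hypothetical stand-in.

```python
import numpy as np

def split_real_and_imaginary(inner):
    def apply(update):
        # Run the real-valued transformation on each component, then recombine.
        return inner(update.real) + 1j * inner(update.imag)
    return apply
```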

scale_by_ademamix([b1, b2, b3, alpha, eps, ...])

Scale updates according to the AdEMAMix algorithm.

ScaleByAdemamixState(count, count_m2, m1, m2, nu)

State for the AdEMAMix algorithm.

scale_by_simplified_ademamix([b1, b2, ...])

Scale updates according to the Simplified AdEMAMix optimizer.

ScaleBySimplifiedAdEMAMixState(t, m, n)

State for the Simplified AdEMAMix optimizer.

scale_by_adopt(b1, b2, eps, mu_dtype, *, ...)

Rescale updates according to the ADOPT algorithm.

scale_by_acprop([b1, b2, eps, eps_root])

Rescale updates according to ACProp (asynchronous version of AdaBelief).

scale_by_madgrad(learning_rate[, momentum, eps])

Rescale updates according to the MADGRAD algorithm.

scale_by_muon([ns_coeffs, ns_steps, beta, ...])

Rescale updates according to the Muon algorithm.

hutchinson_estimator_diag_hessian([random_seed])

Returns a GradientTransformationExtraArgs that estimates the Hessian diagonal.
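Hutchinson's estimator recovers diag(H) from Hessian-vector products alone. For a Rademacher vector v (entries ±1), the expectation of the elementwise product v * (H v) equals diag(H). Below is a minimal NumPy sketch that averages several probes. A dense matrix stands in for the Hessian-vector product, which in practice would come from automatic differentiation. `hutchinson_diag` is a hypothetical helper.

```python
import numpy as np

def hutchinson_diag(hvp, dim, num_samples, rng):
    """Estimate diag(H) using only Hessian-vector products hvp(v) = H @ v."""
    est = np.zeros(dim)
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=dim)   # Rademacher probe
        est += v * hvp(v)                       # unbiased: E[v * (H v)] = diag(H)
    return est / num_samples
```

For a diagonal H every probe is exact. Off-diagonal entries only add zero-mean noise that shrinks as more probes are averaged.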

HutchinsonState(key)