Contrib Examples

Examples that make use of the optax.contrib module.

  • Differentially private convolutional neural network on MNIST
  • Using the Muon Optimizer in Optax
  • Reduce on Plateau Learning Rate Scheduler
  • Recreate AdEMAMix Rosenbrock Plot from Paper
  • Sharpness-Aware Minimization (SAM)

By Optax Contributors

© Copyright 2021, DeepMind.