Index

_
__delattr__() (optax.microbatching.Accumulator method)
__eq__() (optax.microbatching.Accumulator method)
__hash__() (optax.microbatching.Accumulator method)
__init__() (optax.microbatching.Accumulator method)
__setattr__() (optax.microbatching.Accumulator method)

A
acc_grads (optax.MultiStepsState attribute)
AccumulationType (class in optax.microbatching)
Accumulator (class in optax.microbatching)
acprop() (in module optax.contrib)
adabelief() (in module optax)
adadelta() (in module optax)
adafactor() (in module optax)
adagrad() (in module optax)
adam() (in module optax)
adamax() (in module optax)
adamaxw() (in module optax)
adamw() (in module optax)
adan() (in module optax)
adaptive_grad_clip() (in module optax)
AdaptiveGradClipState (in module optax)
add_decayed_weights() (in module optax)
add_noise() (in module optax)
AddDecayedWeightsState (in module optax)
AddNoiseState (class in optax)
ademamix() (in module optax.contrib)
adopt() (in module optax.contrib)
adv_state (optax.contrib.SAMState attribute)
aggregate (optax.microbatching.Accumulator attribute)
amsgrad() (in module optax)
apply_every() (in module optax)
apply_if_finite() (in module optax)
apply_updates() (in module optax)
ApplyEvery (class in optax)
ApplyIfFiniteState (class in optax)

B
bias_correction() (in module optax)
binary_dice_loss() (in module optax.losses)

C
cache (optax.contrib.SAMState attribute)
centralize() (in module optax)
chain() (in module optax)
clip() (in module optax)
clip_by_block_rms() (in module optax)
clip_by_global_norm() (in module optax)
ClipByGlobalNormState (in module optax)
ClipState (in module optax)
cocob() (in module optax.contrib)
COCOBState (class in optax.contrib)
CONCAT (optax.microbatching.AccumulationType attribute)
conditionally_mask() (in module optax)
conditionally_transform() (in module optax)
ConditionallyMaskState (class in optax)
ConditionallyTransformState (class in optax)
constant_schedule() (in module optax.schedules)
convex_kl_divergence() (in module optax.losses)
cosine_decay_schedule() (in module optax.schedules)
cosine_distance() (in module optax.losses)
cosine_onecycle_schedule() (in module optax.schedules)
cosine_similarity() (in module optax.losses)
count
    (optax.contrib.ScaleByAdemamixState attribute)
    (optax.ScaleByLBFGSState attribute)
count_m2 (optax.contrib.ScaleByAdemamixState attribute)
ctc_loss() (in module optax.losses)
ctc_loss_with_forward_probs() (in module optax.losses)
curvature_error (optax.ZoomLinesearchInfo attribute)

D
dadapt_adamw() (in module optax.contrib)
DAdaptAdamWState (class in optax.contrib)
decrease_error (optax.ZoomLinesearchInfo attribute)
dice_loss() (in module optax.losses)
diff_params_memory (optax.ScaleByLBFGSState attribute)
diff_updates_memory (optax.ScaleByLBFGSState attribute)
differentially_private_aggregate() (in module optax.contrib)
DifferentiallyPrivateAggregateState (class in optax.contrib)
dog() (in module optax.contrib)
DoGState (class in optax.contrib)
dowg() (in module optax.contrib)
DoWGState (class in optax.contrib)
dpsgd() (in module optax.contrib)

E
ema() (in module optax)
EmaState (class in optax)
EmptyState (class in optax)
exponential_decay() (in module optax.schedules)

F
FactoredState (class in optax)
fast (optax.LookaheadParams attribute)
fast_state (optax.LookaheadState attribute)
finalize (optax.microbatching.Accumulator attribute)
fisher_diag() (in module optax.second_order)
flatten() (in module optax)
freeze() (in module optax)
fromage() (in module optax)

G
galore() (in module optax.contrib)
generalized_kl_divergence() (in module optax.losses)
global_norm() (in module optax)
grad
    (optax.ScaleByBacktrackingLinesearchState attribute)
    (optax.ScaleByZoomLinesearchState attribute)
gradient_step (optax.MultiStepsState attribute)
GradientTransformation (class in optax)
GradientTransformationExtraArgs (class in optax)
Gumbel (class in optax.perturbations)

H
hessian_diag() (in module optax.second_order)
hinge_loss() (in module optax.losses)
huber_loss() (in module optax.losses)
hungarian_algorithm() (in module optax.assignment)
hutchinson_estimator_diag_hessian() (in module optax.contrib)
HutchinsonState (class in optax.contrib)
hvp() (in module optax.second_order)

I
identity() (in module optax)
incremental_update() (in module optax)
info
    (optax.ScaleByBacktrackingLinesearchState attribute)
    (optax.ScaleByZoomLinesearchState attribute)
init
    (optax.GradientTransformation attribute)
    (optax.microbatching.Accumulator attribute)
inject_hyperparams() (in module optax.schedules)
InjectHyperparamsState (class in optax.schedules)
inner_opt_state (optax.MultiStepsState attribute)

J
join_schedules() (in module optax.schedules)

K
keep_params_nonnegative() (in module optax)
kl_divergence() (in module optax.losses)
kl_divergence_with_log_targets() (in module optax.losses)

L
l2_loss() (in module optax.losses)
lamb() (in module optax)
lars() (in module optax)
last_finite (optax.ApplyIfFiniteState attribute)
lbfgs() (in module optax)
learning_rate
    (optax.ScaleByBacktrackingLinesearchState attribute)
    (optax.ScaleByZoomLinesearchState attribute)
linear_onecycle_schedule() (in module optax.schedules)
linear_schedule() (in module optax.schedules)
lion() (in module optax)
log_cosh() (in module optax.losses)
lookahead() (in module optax)
LookaheadParams (class in optax)
LookaheadState (class in optax)

M
m (optax.contrib.ScaleBySimplifiedAdEMAMixState attribute)
m1 (optax.contrib.ScaleByAdemamixState attribute)
m2 (optax.contrib.ScaleByAdemamixState attribute)
madgrad() (in module optax.contrib)
MadgradState (class in optax.contrib)
make_fenchel_young_loss() (in module optax.losses)
make_perturbed_fun() (in module optax.perturbations)
masked() (in module optax), [1]
MaskedState (class in optax)
matrix_inverse_pth_root() (in module optax)
MEAN (optax.microbatching.AccumulationType attribute)
MechanicState (class in optax.contrib)
mechanize() (in module optax.contrib)
micro_grad() (in module optax.microbatching)
micro_vmap() (in module optax.microbatching)
microbatch() (in module optax.microbatching)
mini_step (optax.MultiStepsState attribute)
momo() (in module optax.contrib)
momo_adam() (in module optax.contrib)
MomoAdamState (class in optax.contrib)
MomoState (class in optax.contrib)
multiclass_generalized_dice_loss() (in module optax.losses)
multiclass_hinge_loss() (in module optax.losses)
multiclass_perceptron_loss() (in module optax.losses)
multiclass_sparsemax_loss() (in module optax.losses)
MultiSteps (class in optax)
MultiStepsState (class in optax)
muon() (in module optax.contrib)
MuonState (class in optax.contrib)

N
n (optax.contrib.ScaleBySimplifiedAdEMAMixState attribute)
nadam() (in module optax)
nadamw() (in module optax)
name (optax.tree_utils.NamedTupleKey attribute)
named_chain() (in module optax)
NamedTupleKey (class in optax.tree_utils)
nnls() (in module optax)
noisy_sgd() (in module optax)
NonNegativeParamsState (in module optax)
Normal (class in optax.perturbations)
normalize_by_update_norm() (in module optax)
notfinite_count (optax.ApplyIfFiniteState attribute)
novograd() (in module optax)
ntxent() (in module optax.losses)
nu (optax.contrib.ScaleByAdemamixState attribute)
num_linesearch_steps (optax.ZoomLinesearchInfo attribute)

O
opt_state (optax.contrib.SAMState attribute)
optimistic_adam_v2() (in module optax)
optimistic_gradient_descent() (in module optax)
OptState (in module optax)

P
Params (in module optax)
params (optax.ScaleByLBFGSState attribute)
partition() (in module optax)
PartitionState (class in optax)
per_example_global_norm_clip() (in module optax)
per_example_layer_norm_clip() (in module optax)
perceptron_loss() (in module optax.losses)
periodic_update() (in module optax)
piecewise_constant_schedule() (in module optax.schedules)
piecewise_interpolate_schedule() (in module optax.schedules)
poly_loss_cross_entropy() (in module optax.losses)
polyak_sgd() (in module optax)
polynomial_schedule() (in module optax.schedules)
power_iteration() (in module optax)
prodigy() (in module optax.contrib)
ProdigyState (class in optax.contrib)
projection_box() (in module optax.projections)
projection_halfspace() (in module optax.projections)
projection_hypercube() (in module optax.projections)
projection_hyperplane() (in module optax.projections)
projection_l1_ball() (in module optax.projections)
projection_l1_sphere() (in module optax.projections)
projection_l2_ball() (in module optax.projections)
projection_l2_sphere() (in module optax.projections)
projection_linf_ball() (in module optax.projections)
projection_non_negative() (in module optax.projections)
projection_simplex() (in module optax.projections)
projection_vector() (in module optax.projections)

R
radam() (in module optax)
ranking_softmax_loss() (in module optax.losses)
reduce_on_plateau() (in module optax.contrib)
reshape_batch_axis() (in module optax.microbatching)
rmsprop() (in module optax)
rprop() (in module optax)
RUNNING_MEAN (optax.microbatching.AccumulationType attribute)

S
safe_increment() (in module optax)
safe_norm() (in module optax)
safe_root_mean_squares() (in module optax)
safe_softmax_cross_entropy() (in module optax.losses)
sam() (in module optax.contrib)
SAMState (class in optax.contrib)
scale() (in module optax)
scale_by_acprop() (in module optax.contrib)
scale_by_adadelta() (in module optax)
scale_by_adam() (in module optax)
scale_by_adamax() (in module optax)
scale_by_adan() (in module optax)
scale_by_ademamix() (in module optax.contrib)
scale_by_adopt() (in module optax.contrib)
scale_by_amsgrad() (in module optax)
scale_by_backtracking_linesearch() (in module optax)
scale_by_belief() (in module optax)
scale_by_factored_rms() (in module optax)
scale_by_lbfgs() (in module optax)
scale_by_learning_rate() (in module optax)
scale_by_lion() (in module optax)
scale_by_madgrad() (in module optax.contrib)
scale_by_muon() (in module optax.contrib)
scale_by_novograd() (in module optax)
scale_by_optimistic_gradient() (in module optax)
scale_by_param_block_norm() (in module optax)
scale_by_param_block_rms() (in module optax)
scale_by_polyak() (in module optax)
scale_by_radam() (in module optax)
scale_by_rms() (in module optax)
scale_by_rprop() (in module optax)
scale_by_rss() (in module optax)
scale_by_schedule() (in module optax)
scale_by_sign() (in module optax)
scale_by_simplified_ademamix() (in module optax.contrib)
scale_by_sm3() (in module optax)
scale_by_stddev() (in module optax)
scale_by_trust_ratio() (in module optax)
scale_by_yogi() (in module optax)
scale_by_zoom_linesearch() (in module optax)
scale_gradient() (in module optax)
ScaleByAdaDeltaState (class in optax)
ScaleByAdamState (class in optax)
ScaleByAdanState (class in optax)
ScaleByAdemamixState (class in optax.contrib)
ScaleByAmsgradState (class in optax)
ScaleByBacktrackingLinesearchState (class in optax)
ScaleByBeliefState (class in optax)
ScaleByLBFGSState (class in optax)
ScaleByLionState (class in optax)
ScaleByNovogradState (class in optax)
ScaleByRmsState (class in optax)
ScaleByRpropState (class in optax)
ScaleByRssState (class in optax)
ScaleByRStdDevState (class in optax)
ScaleByScheduleState (class in optax)
ScaleBySimplifiedAdEMAMixState (class in optax.contrib)
ScaleBySM3State (class in optax)
ScaleByTrustRatioState (in module optax)
ScaleByZoomLinesearchState (class in optax)
ScaleState (in module optax)
Schedule (in module optax.schedules)
schedule_free() (in module optax.contrib)
schedule_free_adamw() (in module optax.contrib)
schedule_free_eval_params() (in module optax.contrib)
schedule_free_sgd() (in module optax.contrib)
ScheduleFreeState (class in optax.contrib)
selective_transform() (in module optax)
set_to_zero() (in module optax)
sgd() (in module optax)
sgdr_schedule() (in module optax.schedules)
ShouldSkipUpdateFunction (class in optax)
sigmoid_binary_cross_entropy() (in module optax.losses)
sigmoid_focal_loss() (in module optax.losses)
sign_sgd() (in module optax)
signum() (in module optax)
simplified_ademamix() (in module optax.contrib)
skip_large_updates() (in module optax)
skip_not_finite() (in module optax)
skip_state (optax.MultiStepsState attribute)
slow (optax.LookaheadParams attribute)
sm3() (in module optax)
smooth_labels() (in module optax.losses)
snapshot() (in module optax)
SnapshotState (class in optax)
softmax_cross_entropy() (in module optax.losses)
softmax_cross_entropy_with_integer_labels() (in module optax.losses)
sophia() (in module optax.contrib)
SophiaState (class in optax.contrib)
sparsemax_loss() (in module optax.losses)
split_real_and_imaginary() (in module optax.contrib)
SplitRealAndImaginaryState (class in optax.contrib)
squared_error() (in module optax.losses)
stateless() (in module optax)
stateless_with_tree_map() (in module optax)
steps_since_sync
    (optax.contrib.SAMState attribute)
    (optax.LookaheadState attribute)
SUM (optax.microbatching.AccumulationType attribute)

T
t (optax.contrib.ScaleBySimplifiedAdEMAMixState attribute)
total_notfinite (optax.ApplyIfFiniteState attribute)
trace() (in module optax)
TraceState (class in optax)
TransformInitFn (class in optax)
TransformUpdateExtraArgsFn (class in optax)
TransformUpdateFn (class in optax)
tree_add() (in module optax.tree_utils)
tree_add_scale() (in module optax.tree_utils)
tree_allclose() (in module optax.tree_utils)
tree_batch_shape() (in module optax.tree_utils)
tree_cast() (in module optax.tree_utils)
tree_cast_like() (in module optax.tree_utils)
tree_clip() (in module optax.tree_utils)
tree_conj() (in module optax.tree_utils)
tree_div() (in module optax.tree_utils)
tree_dtype() (in module optax.tree_utils)
tree_full_like() (in module optax.tree_utils)
tree_get() (in module optax.tree_utils)
tree_get_all_with_path() (in module optax.tree_utils)
tree_map_params() (in module optax.tree_utils)
tree_max() (in module optax.tree_utils)
tree_min() (in module optax.tree_utils)
tree_mul() (in module optax.tree_utils)
tree_norm() (in module optax.tree_utils)
tree_ones_like() (in module optax.tree_utils)
tree_random_like() (in module optax.tree_utils)
tree_real() (in module optax.tree_utils)
tree_scale() (in module optax.tree_utils)
tree_set() (in module optax.tree_utils)
tree_size() (in module optax.tree_utils)
tree_split_key_like() (in module optax.tree_utils)
tree_sub() (in module optax.tree_utils)
tree_sum() (in module optax.tree_utils)
tree_vdot() (in module optax.tree_utils)
tree_where() (in module optax.tree_utils)
tree_zeros_like() (in module optax.tree_utils)
triplet_margin_loss() (in module optax.losses)
tuple_name (optax.tree_utils.NamedTupleKey attribute)

U
update
    (optax.GradientTransformation attribute)
    (optax.GradientTransformationExtraArgs attribute)
    (optax.microbatching.Accumulator attribute)
update_infinity_moment() (in module optax)
update_moment() (in module optax)
update_moment_per_elem_norm() (in module optax)
Updates (in module optax)
updates (optax.ScaleByLBFGSState attribute)

V
value
    (optax.ScaleByBacktrackingLinesearchState attribute)
    (optax.ScaleByZoomLinesearchState attribute)
value_and_grad_from_state() (in module optax)

W
warmup_constant_schedule() (in module optax.schedules)
warmup_cosine_decay_schedule() (in module optax.schedules)
warmup_exponential_decay_schedule() (in module optax.schedules)
weights_memory (optax.ScaleByLBFGSState attribute)
with_extra_args_support() (in module optax)

Y
yogi() (in module optax)

Z
zero_nans() (in module optax)
ZeroNansState (class in optax)
ZoomLinesearchInfo (class in optax)