Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | R | S | T | U | V | W | X | Y | Z _ __call__() (optax.ShouldSkipUpdateFunction method) (optax.TransformInitFn method) (optax.TransformUpdateFn method) __eq__() (optax.contrib.SAMState method) __getnewargs__() (optax.AddNoiseState method) (optax.ApplyEvery method) (optax.ApplyIfFiniteState method) (optax.contrib.COCOBState method) (optax.contrib.DAdaptAdamWState method) (optax.contrib.DifferentiallyPrivateAggregateState method) (optax.contrib.MechanicState method) (optax.contrib.ProdigyState method) (optax.contrib.SplitRealAndImaginaryState method) (optax.EmaState method) (optax.EmptyState method) (optax.FactoredState method) (optax.GradientTransformation method) (optax.InjectHyperparamsState method) (optax.LookaheadParams method) (optax.LookaheadState method) (optax.MaskedState method) (optax.MaybeUpdateState method) (optax.MultiStepsState method) (optax.MultiTransformState method) (optax.ScaleByAdaDeltaState method) (optax.ScaleByAdamState method) (optax.ScaleByAmsgradState method) (optax.ScaleByBeliefState method) (optax.ScaleByLionState method) (optax.ScaleByNovogradState method) (optax.ScaleByRmsState method) (optax.ScaleByRpropState method) (optax.ScaleByRssState method) (optax.ScaleByRStdDevState method) (optax.ScaleByScheduleState method) (optax.ScaleBySM3State method) (optax.ScaleByTrustRatioState method) (optax.TraceState method) (optax.ZeroNansState method) __getstate__() (optax.contrib.SAMState method) __hash__ (optax.contrib.SAMState attribute) __init__() (optax.contrib.SAMState method) (optax.MultiSteps method) (optax.ShouldSkipUpdateFunction method) (optax.TransformInitFn method) (optax.TransformUpdateFn method) __new__() (optax.EmptyState static method) (optax.ScaleByTrustRatioState static method) __subclasshook__() (optax.ShouldSkipUpdateFunction method) (optax.TransformInitFn method) (optax.TransformUpdateFn method) A acc_grads (optax.MultiStepsState attribute) adabelief() (in module optax) adadelta() (in module optax) adafactor() (in module optax) adagrad() (in module optax) adam() (in module optax) adamax() (in module optax) adamaxw() (in module optax) adamw() (in module optax) adaptive_grad_clip() (in module optax) AdaptiveGradClipState (in module optax) add_decayed_weights() (in module optax) add_noise() (in module optax) AddDecayedWeightsState (in module optax) AddNoiseState (class in optax) adv_state (optax.contrib.SAMState attribute) amsgrad() (in module optax) apply_every() (in module optax) apply_if_finite() (in module optax) apply_updates() (in module optax) ApplyEvery (class in optax) ApplyIfFiniteState (class in optax) B base_optimizer_state (optax.contrib.MechanicState attribute) bias_correction() (in module optax) C cache (optax.contrib.SAMState attribute) centralize() (in module optax) chain() (in module optax) clip() (in module optax) clip_by_block_rms() (in module optax) clip_by_global_norm() (in module optax) ClipByGlobalNormState (in module optax) ClipState (in module optax) cocob() (in module optax.contrib) COCOBState (class in optax.contrib) constant_schedule() (in module optax) control_delta_method() (in module optax.monte_carlo) control_variates_jacobians() (in module optax.monte_carlo) convex_kl_divergence() (in module optax) cosine_decay_schedule() (in module optax) cosine_distance() (in module optax) cosine_onecycle_schedule() (in module optax) cosine_similarity() (in module optax) count (optax.AddNoiseState attribute) (optax.ApplyEvery attribute) (optax.contrib.DAdaptAdamWState attribute) (optax.contrib.MechanicState attribute) (optax.contrib.ProdigyState attribute) (optax.EmaState attribute) (optax.FactoredState attribute) (optax.InjectHyperparamsState attribute) (optax.ScaleByAdamState attribute) (optax.ScaleByAmsgradState attribute) (optax.ScaleByBeliefState attribute) (optax.ScaleByLionState attribute) (optax.ScaleByNovogradState attribute) (optax.ScaleByScheduleState attribute) ctc_loss() (in module optax) ctc_loss_with_forward_probs() (in module optax) cumulative_gradients (optax.contrib.COCOBState attribute) D dadapt_adamw() (in module optax.contrib) DAdaptAdamWState (class in optax.contrib) differentially_private_aggregate() (in module optax.contrib) DifferentiallyPrivateAggregateState (class in optax.contrib) dpsgd() (in module optax.contrib) E e_g (optax.ScaleByAdaDeltaState attribute) e_x (optax.ScaleByAdaDeltaState attribute) ema (optax.EmaState attribute) ema() (in module optax) EmaState (class in optax) EmptyState (class in optax) estim_lr (optax.contrib.DAdaptAdamWState attribute) (optax.contrib.ProdigyState attribute) exp_avg (optax.contrib.DAdaptAdamWState attribute) (optax.contrib.ProdigyState attribute) exp_avg_sq (optax.contrib.DAdaptAdamWState attribute) (optax.contrib.ProdigyState attribute) exponential_decay() (in module optax) F FactoredState (class in optax) fast (optax.LookaheadParams attribute), [1] fast_state (optax.LookaheadState attribute), [1] fisher_diag() (in module optax.second_order) flatten() (in module optax) found_nan (optax.ZeroNansState attribute) fromage() (in module optax) G global_norm() (in module optax) grad (optax.ScaleByBacktrackingLinesearchState attribute) grad_acc (optax.ApplyEvery attribute) grad_sum (optax.contrib.DAdaptAdamWState attribute) (optax.contrib.ProdigyState attribute) gradient_step (optax.MultiStepsState attribute) GradientTransformation (class in optax) GradientTransformationExtraArgs (class in optax) H hessian_diag() (in module optax.second_order) hinge_loss() (in module optax) huber_loss() (in module optax) hvp() (in module optax.second_order) hyperparams (optax.InjectHyperparamsState attribute) I identity() (in module optax) incremental_update() (in module optax) init (optax.GradientTransformation attribute), [1] init() (optax.MultiSteps method) init_particles (optax.contrib.COCOBState attribute) init_synced() (optax.LookaheadParams class method) inject_hyperparams() (in module optax) InjectHyperparamsState (class in optax) inner_opt_state (optax.MultiStepsState attribute) inner_state (optax.ApplyIfFiniteState attribute) (optax.contrib.SplitRealAndImaginaryState attribute) (optax.InjectHyperparamsState attribute) (optax.MaskedState attribute) (optax.MaybeUpdateState attribute) inner_states (optax.MultiTransformState attribute) items() (optax.contrib.SAMState method) J join_schedules() (in module optax) K keep_params_nonnegative() (in module optax) keys() (optax.contrib.SAMState method) kl_divergence() (in module optax) L l2_loss() (in module optax) lamb() (in module optax) lars() (in module optax) last_finite (optax.ApplyIfFiniteState attribute) learning_rate (optax.ScaleByBacktrackingLinesearchState attribute) linear_onecycle_schedule() (in module optax) linear_schedule() (in module optax) lion() (in module optax) log_cosh() (in module optax) lookahead() (in module optax) LookaheadParams (class in optax) LookaheadState (class in optax) M m (optax.contrib.MechanicState attribute) masked() (in module optax) MaskedState (class in optax) matrix_inverse_pth_root() (in module optax) maybe_update() (in module optax) MaybeUpdateState (class in optax) measure_valued_jacobians() (in module optax.monte_carlo) MechanicState (class in optax.contrib) mechanize() (in module optax.contrib) mini_step (optax.MultiStepsState attribute) moving_avg_baseline() (in module optax.monte_carlo) mu (optax.ScaleByAdamState attribute) (optax.ScaleByAmsgradState attribute) (optax.ScaleByBeliefState attribute) (optax.ScaleByLionState attribute) (optax.ScaleByNovogradState attribute) (optax.ScaleByRStdDevState attribute) (optax.ScaleBySM3State attribute) multi_normal() (in module optax) multi_transform() (in module optax) MultiSteps (class in optax) MultiStepsState (class in optax) MultiTransformState (class in optax) N nadam() (in module optax) nadamw() (in module optax) name (optax.tree_utils.NamedTupleKey attribute) NamedTupleKey (class in optax.tree_utils) noisy_sgd() (in module optax) NonNegativeParamsState (in module optax) notfinite_count (optax.ApplyIfFiniteState attribute) novograd() (in module optax) nu (optax.ScaleByAdamState attribute) (optax.ScaleByAmsgradState attribute) (optax.ScaleByBeliefState attribute) (optax.ScaleByNovogradState attribute) (optax.ScaleByRmsState attribute) (optax.ScaleByRStdDevState attribute) (optax.ScaleBySM3State attribute) nu_max (optax.ScaleByAmsgradState attribute) numerator_weighted (optax.contrib.DAdaptAdamWState attribute) (optax.contrib.ProdigyState attribute) O opt_state (optax.contrib.SAMState attribute) optimistic_gradient_descent() (in module optax) OptState (in module optax) P Params (in module optax) params0 (optax.contrib.ProdigyState attribute) pathwise_jacobians() (in module optax.monte_carlo) per_example_global_norm_clip() (in module optax) per_example_layer_norm_clip() (in module optax) periodic_update() (in module optax) piecewise_constant_schedule() (in module optax) piecewise_interpolate_schedule() (in module optax) polyak_sgd() (in module optax) polynomial_schedule() (in module optax) power_iteration() (in module optax) prev_updates (optax.ScaleByRpropState attribute) prodigy() (in module optax.contrib) ProdigyState (class in optax.contrib) R r (optax.contrib.MechanicState attribute) radam() (in module optax) reduce_on_plateau() (in module optax.contrib) reward (optax.contrib.COCOBState attribute) rmsprop() (in module optax) rng_key (optax.AddNoiseState attribute) (optax.contrib.DifferentiallyPrivateAggregateState attribute) rprop() (in module optax) S s (optax.contrib.MechanicState attribute) safe_int32_increment() (in module optax) safe_norm() (in module optax) safe_root_mean_squares() (in module optax) sam() (in module optax.contrib) SAMState (class in optax.contrib) scale (optax.contrib.COCOBState attribute) scale() (in module optax) scale_by_adadelta() (in module optax) scale_by_adam() (in module optax) scale_by_adamax() (in module optax) scale_by_amsgrad() (in module optax) scale_by_backtracking_linesearch() (in module optax) scale_by_belief() (in module optax) scale_by_factored_rms() (in module optax) scale_by_learning_rate() (in module optax) scale_by_lion() (in module optax) scale_by_novograd() (in module optax) scale_by_optimistic_gradient() (in module optax) scale_by_param_block_norm() (in module optax) scale_by_param_block_rms() (in module optax) scale_by_polyak() (in module optax) scale_by_radam() (in module optax) scale_by_rms() (in module optax) scale_by_rprop() (in module optax) scale_by_rss() (in module optax) scale_by_schedule() (in module optax) scale_by_sm3() (in module optax) scale_by_stddev() (in module optax) scale_by_trust_ratio() (in module optax) scale_by_yogi() (in module optax) scale_gradient() (in module optax) ScaleByAdaDeltaState (class in optax) ScaleByAdamState (class in optax) ScaleByAmsgradState (class in optax) ScaleByBacktrackingLinesearchState (class in optax) ScaleByBeliefState (class in optax) ScaleByLionState (class in optax) ScaleByNovogradState (class in optax) ScaleByRmsState (class in optax) ScaleByRpropState (class in optax) ScaleByRssState (class in optax) ScaleByRStdDevState (class in optax) ScaleByScheduleState (class in optax) ScaleBySM3State (class in optax) ScaleByTrustRatioState (class in optax) ScaleState (in module optax) Schedule (in module optax) score_function_jacobians() (in module optax.monte_carlo) set_to_zero() (in module optax) sgd() (in module optax) sgdr_schedule() (in module optax) ShouldSkipUpdateFunction (class in optax) sigmoid_binary_cross_entropy() (in module optax) sigmoid_focal_loss() (in module optax) skip_large_updates() (in module optax) skip_not_finite() (in module optax) skip_state (optax.MultiStepsState attribute) slow (optax.LookaheadParams attribute), [1] sm3() (in module optax) smooth_labels() (in module optax) softmax_cross_entropy() (in module optax) softmax_cross_entropy_with_integer_labels() (in module optax) split_real_and_imaginary() (in module optax.contrib) SplitRealAndImaginaryState (class in optax.contrib) squared_error() (in module optax) stateless() (in module optax) stateless_with_tree_map() (in module optax) step (optax.MaybeUpdateState attribute) step_sizes (optax.ScaleByRpropState attribute) steps_since_sync (optax.contrib.SAMState attribute) (optax.LookaheadState attribute), [1] subgradients (optax.contrib.COCOBState attribute) sum_of_squares (optax.ScaleByRssState attribute) T total_notfinite (optax.ApplyIfFiniteState attribute) trace (optax.TraceState attribute) trace() (in module optax) TraceState (class in optax) TransformInitFn (class in optax) TransformUpdateFn (class in optax) tree_add() (in module optax.tree_utils) tree_add_scalar_mul() (in module optax.tree_utils) tree_div() (in module optax.tree_utils) tree_get() (in module optax.tree_utils) tree_get_all_with_path() (in module optax.tree_utils) tree_l2_norm() (in module optax.tree_utils) tree_map_params() (in module optax.tree_utils) tree_mul() (in module optax.tree_utils) tree_ones_like() (in module optax.tree_utils) tree_random_like() (in module optax.tree_utils) tree_scalar_mul() (in module optax.tree_utils) tree_set() (in module optax.tree_utils) tree_sub() (in module optax.tree_utils) tree_sum() (in module optax.tree_utils) tree_vdot() (in module optax.tree_utils) tree_zeros_like() (in module optax.tree_utils) tuple_name (optax.tree_utils.NamedTupleKey attribute) U update (optax.GradientTransformation attribute), [1] (optax.GradientTransformationExtraArgs attribute) update() (optax.MultiSteps method) update_infinity_moment() (in module optax) update_moment() (in module optax) update_moment_per_elem_norm() (in module optax) Updates (in module optax) V v (optax.contrib.MechanicState attribute) (optax.FactoredState attribute) v_col (optax.FactoredState attribute) v_row (optax.FactoredState attribute) value (optax.ScaleByBacktrackingLinesearchState attribute) value_and_grad_from_state() (in module optax) values() (optax.contrib.SAMState method) W warmup_cosine_decay_schedule() (in module optax) warmup_exponential_decay_schedule() (in module optax) with_extra_args_support() (in module optax) X x0 (optax.contrib.MechanicState attribute) Y yogi() (in module optax) Z zero_nans() (in module optax) ZeroNansState (class in optax)