Toggle navigation sidebar
Toggle in-page Table of Contents
Getting Started
Learn Optax
Examples
Gradient Accumulation
Meta-Learning
Developer Documentation
Design Documents
Contributors
API Documentation
Common Optimizers
Index
_ | A | C | D | E | F | G | H | I | J | K | L | M | N | O | P | R | S | T | U | V | W | Y | Z
_
__call__() (optax.TransformInitFn method)
(optax.TransformUpdateFn method)
__getnewargs__() (optax.AddNoiseState method)
(optax.ApplyEvery method)
(optax.ApplyIfFiniteState method)
(optax.DifferentiallyPrivateAggregateState method)
(optax.EmaState method)
(optax.EmptyState method)
(optax.experimental.SplitRealAndImaginaryState method)
(optax.FactoredState method)
(optax.GradientTransformation method)
(optax.InjectHyperparamsState method)
(optax.LookaheadParams method)
(optax.LookaheadState method)
(optax.MaskedState method)
(optax.MaybeUpdateState method)
(optax.MultiStepsState method)
(optax.MultiTransformState method)
(optax.ScaleByAdamState method)
(optax.ScaleByAmsgradState method)
(optax.ScaleByLionState method)
(optax.ScaleByNovogradState method)
(optax.ScaleByRmsState method)
(optax.ScaleByRssState method)
(optax.ScaleByRStdDevState method)
(optax.ScaleByScheduleState method)
(optax.ScaleBySM3State method)
(optax.ScaleByTrustRatioState method)
(optax.TraceState method)
(optax.ZeroNansState method)
__init__() (optax.MultiSteps method)
(optax.TransformInitFn method)
(optax.TransformUpdateFn method)
__new__() (optax.EmptyState static method)
(optax.ScaleByTrustRatioState static method)
__subclasshook__() (optax.TransformInitFn method)
(optax.TransformUpdateFn method)
A
acc_grads (optax.MultiStepsState attribute)
adabelief() (in module optax)
adafactor() (in module optax)
adagrad() (in module optax)
adam() (in module optax)
adamax() (in module optax)
adamaxw() (in module optax)
adamw() (in module optax)
adaptive_grad_clip() (in module optax)
AdaptiveGradClipState (in module optax)
add_decayed_weights() (in module optax)
add_noise() (in module optax)
AddDecayedWeightsState (in module optax)
additive_weight_decay() (in module optax)
AdditiveWeightDecayState (in module optax)
AddNoiseState (class in optax)
amsgrad() (in module optax)
apply_every() (in module optax)
apply_if_finite() (in module optax)
apply_updates() (in module optax)
ApplyEvery (class in optax)
ApplyIfFiniteState (class in optax)
C
centralize() (in module optax)
chain() (in module optax)
clip() (in module optax)
clip_by_block_rms() (in module optax)
clip_by_global_norm() (in module optax)
ClipByGlobalNormState (in module optax)
ClipState (in module optax)
constant_schedule() (in module optax)
control_delta_method() (in module optax)
control_variates_jacobians() (in module optax)
convex_kl_divergence() (in module optax)
cosine_decay_schedule() (in module optax)
cosine_distance() (in module optax)
cosine_onecycle_schedule() (in module optax)
cosine_similarity() (in module optax)
count (optax.AddNoiseState attribute)
(optax.ApplyEvery attribute)
(optax.EmaState attribute)
(optax.FactoredState attribute)
(optax.InjectHyperparamsState attribute)
(optax.ScaleByAdamState attribute)
(optax.ScaleByAmsgradState attribute)
(optax.ScaleByLionState attribute)
(optax.ScaleByNovogradState attribute)
(optax.ScaleByScheduleState attribute)
ctc_loss() (in module optax)
ctc_loss_with_forward_probs() (in module optax)
D
differentially_private_aggregate() (in module optax)
DifferentiallyPrivateAggregateState (class in optax)
dpsgd() (in module optax)
E
ema (optax.EmaState attribute)
ema() (in module optax)
EmaState (class in optax)
EmptyState (class in optax)
exponential_decay() (in module optax)
F
FactoredState (class in optax)
fast (optax.LookaheadParams attribute), [1]
fast_state (optax.LookaheadState attribute), [1]
fisher_diag() (in module optax)
flatten() (in module optax)
found_nan (optax.ZeroNansState attribute)
fromage() (in module optax)
G
global_norm() (in module optax)
grad_acc (optax.ApplyEvery attribute)
gradient_step (optax.MultiStepsState attribute)
GradientTransformation (class in optax)
H
hessian_diag() (in module optax)
hinge_loss() (in module optax)
huber_loss() (in module optax)
hvp() (in module optax)
hyperparams (optax.InjectHyperparamsState attribute)
I
identity() (in module optax)
incremental_update() (in module optax)
init (optax.GradientTransformation attribute), [1]
init() (optax.MultiSteps method)
init_synced() (optax.LookaheadParams class method)
inject_hyperparams() (in module optax)
InjectHyperparamsState (class in optax)
inner_opt_state (optax.MultiStepsState attribute)
inner_state (optax.ApplyIfFiniteState attribute)
(optax.experimental.SplitRealAndImaginaryState attribute)
(optax.InjectHyperparamsState attribute)
(optax.MaskedState attribute)
(optax.MaybeUpdateState attribute)
inner_states (optax.MultiTransformState attribute)
J
join_schedules() (in module optax)
K
keep_params_nonnegative() (in module optax)
kl_divergence() (in module optax)
L
l2_loss() (in module optax)
lamb() (in module optax)
lars() (in module optax)
last_finite (optax.ApplyIfFiniteState attribute)
linear_onecycle_schedule() (in module optax)
linear_schedule() (in module optax)
lion() (in module optax)
log_cosh() (in module optax)
lookahead() (in module optax)
LookaheadParams (class in optax)
LookaheadState (class in optax)
M
masked() (in module optax)
MaskedState (class in optax)
matrix_inverse_pth_root() (in module optax)
maybe_update() (in module optax)
MaybeUpdateState (class in optax)
measure_valued_jacobians() (in module optax)
mini_step (optax.MultiStepsState attribute)
moving_avg_baseline() (in module optax)
mu (optax.ScaleByAdamState attribute)
(optax.ScaleByAmsgradState attribute)
(optax.ScaleByLionState attribute)
(optax.ScaleByNovogradState attribute)
(optax.ScaleByRStdDevState attribute)
(optax.ScaleBySM3State attribute)
multi_normal() (in module optax), [1]
multi_transform() (in module optax)
MultiSteps (class in optax)
MultiStepsState (class in optax)
MultiTransformState (class in optax)
N
noisy_sgd() (in module optax)
NonNegativeParamsState (in module optax)
notfinite_count (optax.ApplyIfFiniteState attribute)
novograd() (in module optax)
nu (optax.ScaleByAdamState attribute)
(optax.ScaleByAmsgradState attribute)
(optax.ScaleByNovogradState attribute)
(optax.ScaleByRmsState attribute)
(optax.ScaleByRStdDevState attribute)
(optax.ScaleBySM3State attribute)
nu_max (optax.ScaleByAmsgradState attribute)
O
optimistic_gradient_descent() (in module optax)
OptState (in module optax)
P
Params (in module optax)
pathwise_jacobians() (in module optax)
periodic_update() (in module optax)
piecewise_constant_schedule() (in module optax)
piecewise_interpolate_schedule() (in module optax)
polynomial_schedule() (in module optax)
power_iteration() (in module optax)
R
radam() (in module optax)
rmsprop() (in module optax)
rng_key (optax.AddNoiseState attribute)
(optax.DifferentiallyPrivateAggregateState attribute)
S
safe_int32_increment() (in module optax)
safe_norm() (in module optax)
safe_root_mean_squares() (in module optax)
scale() (in module optax)
scale_by_adam() (in module optax)
scale_by_adamax() (in module optax)
scale_by_amsgrad() (in module optax)
scale_by_belief() (in module optax)
scale_by_factored_rms() (in module optax)
scale_by_lion() (in module optax)
scale_by_novograd() (in module optax)
scale_by_param_block_norm() (in module optax)
scale_by_param_block_rms() (in module optax)
scale_by_radam() (in module optax)
scale_by_rms() (in module optax)
scale_by_rss() (in module optax)
scale_by_schedule() (in module optax)
scale_by_sm3() (in module optax)
scale_by_stddev() (in module optax)
scale_by_trust_ratio() (in module optax)
scale_by_yogi() (in module optax)
scale_gradient() (in module optax)
ScaleByAdamState (class in optax)
ScaleByAmsgradState (class in optax)
ScaleByLionState (class in optax)
ScaleByNovogradState (class in optax)
ScaleByRmsState (class in optax)
ScaleByRssState (class in optax)
ScaleByRStdDevState (class in optax)
ScaleByScheduleState (class in optax)
ScaleBySM3State (class in optax)
ScaleByTrustRatioState (class in optax)
ScaleState (in module optax)
Schedule (in module optax)
score_function_jacobians() (in module optax)
set_to_zero() (in module optax)
sgd() (in module optax)
sgdr_schedule() (in module optax)
sigmoid_binary_cross_entropy() (in module optax)
skip_state (optax.MultiStepsState attribute)
slow (optax.LookaheadParams attribute), [1]
sm3() (in module optax)
smooth_labels() (in module optax)
softmax_cross_entropy() (in module optax)
softmax_cross_entropy_with_integer_labels() (in module optax)
split_real_and_imaginary() (in module optax.experimental)
SplitRealAndImaginaryState (class in optax.experimental)
squared_error() (in module optax)
stateless() (in module optax)
stateless_with_tree_map() (in module optax)
step (optax.MaybeUpdateState attribute)
steps_since_sync (optax.LookaheadState attribute), [1]
sum_of_squares (optax.ScaleByRssState attribute)
T
total_notfinite (optax.ApplyIfFiniteState attribute)
trace (optax.TraceState attribute)
trace() (in module optax)
TraceState (class in optax)
TransformInitFn (class in optax)
TransformUpdateFn (class in optax)
U
update (optax.GradientTransformation attribute), [1]
update() (optax.MultiSteps method)
Updates (in module optax)
V
v (optax.FactoredState attribute)
v_col (optax.FactoredState attribute)
v_row (optax.FactoredState attribute)
W
warmup_cosine_decay_schedule() (in module optax)
warmup_exponential_decay_schedule() (in module optax)
Y
yogi() (in module optax)
Z
zero_nans() (in module optax)
ZeroNansState (class in optax)