Transformations

adaptive_grad_clip(clipping[, eps, axis])

Clips updates to be at most clipping * parameter_norm, unit-wise.

AdaptiveGradClipState

add_decayed_weights([weight_decay, mask])

Add the parameters, scaled by weight_decay, to the updates.

AddDecayedWeightsState

add_noise(eta, gamma[, key, seed])

Add gradient noise.

AddNoiseState(count, rng_key)

State for adding gradient noise.

apply_every([k])

Accumulate gradients and apply them every k steps.

ApplyEvery(count, grad_acc)

Contains a counter and a gradient accumulator.
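
A minimal pure-Python sketch of the accumulate-then-apply behavior described for apply_every. It assumes updates are a dict mapping names to lists of floats (a stand-in for a JAX pytree) and returns a plain (init, update) pair rather than an Optax object. This sketch sums the updates inside each window of k steps, emits that sum on the k-th step, and emits zeros in between.

```python
def apply_every(k=1):
    def init(params):
        # Mirrors ApplyEvery(count, grad_acc): a counter and a zeroed accumulator.
        return {"count": 0,
                "grad_acc": {n: [0.0] * len(v) for n, v in params.items()}}

    def update(updates, state):
        c = state["count"] % k
        # Restart the sum on the first step of each window, otherwise keep adding.
        keep = 0.0 if c == 0 else 1.0
        grad_acc = {n: [keep * a + g for a, g in zip(state["grad_acc"][n], v)]
                    for n, v in updates.items()}
        # Emit the accumulated sum only on the last step of the window.
        emit = 1.0 if c == k - 1 else 0.0
        new_updates = {n: [emit * a for a in v] for n, v in grad_acc.items()}
        return new_updates, {"count": (state["count"] + 1) % k, "grad_acc": grad_acc}

    return init, update
```

With k=2, the first call returns zeros and the second returns the sum of both gradients.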

bias_correction(moment, decay, count)

Performs bias correction by dividing a moment estimate by 1 - decay**count.
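
As a sketch of the arithmetic, assuming moments are stored as a dict of float lists: an exponential moving average initialized at zero underestimates its target during the first steps, and dividing by 1 - decay**count removes that bias.

```python
def bias_correction(moment, decay, count):
    # Early in training (small count) the zero-initialized EMA is too small;
    # dividing by (1 - decay**count) rescales it back toward the true average.
    correction = 1.0 - decay ** count
    return {n: [m / correction for m in v] for n, v in moment.items()}
```

For example, one EMA step with decay 0.9 on a gradient of 1.0 yields a moment of 0.1, which the correction maps back to 1.0.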

conditionally_mask(inner, should_transform_fn)

Calls the inner update function only at certain steps; at other steps, the updates are set to zero.

conditionally_transform(inner, ...[, ...])

Calls the inner update function only at certain steps; at other steps, the updates and inner state are passed through unchanged.

ConditionallyMaskState(step, inner_state)

ConditionallyTransformState(inner_state, step)

Maintains inner transform state and adds a step counter.

centralize()

Centralizes gradients by subtracting their mean along the leading dimension.

clip(max_delta)

Clips updates element-wise, to be in [-max_delta, +max_delta].

clip_by_block_rms(threshold)

Clips updates to a maximum RMS for the gradient of each param vector or matrix.

ClipState

clip_by_global_norm(max_norm)

Clips updates using their global norm.

ClipByGlobalNormState
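
The global-norm clipping rule can be sketched in pure Python as follows, with updates represented as a dict of float lists (an assumption for illustration, not Optax's implementation). When the norm exceeds max_norm, every leaf is scaled by the same factor, which preserves the direction of the full update.

```python
import math


def global_norm(updates):
    # Square root of the sum of squares of every element in every leaf.
    return math.sqrt(sum(x * x for leaf in updates.values() for x in leaf))


def clip_by_global_norm(updates, max_norm):
    g_norm = global_norm(updates)
    if g_norm <= max_norm:
        return updates  # Already within the budget: leave unchanged.
    # One shared factor for all leaves, so the new global norm equals max_norm.
    factor = max_norm / g_norm
    return {n: [x * factor for x in leaf] for n, leaf in updates.items()}
```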

ema(decay[, debias, accumulator_dtype])

Compute an exponential moving average of past updates.

EmaState(count, ema)

Holds an exponential moving average of past updates.
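
A pure-Python sketch of ema with optional debiasing, assuming dict-of-list updates in place of pytrees. In this sketch the state keeps the raw average, and only the returned value is bias-corrected.

```python
def ema(decay, debias=True):
    def init(params):
        return {"count": 0, "ema": {n: [0.0] * len(v) for n, v in params.items()}}

    def update(updates, state):
        # new_ema = (1 - decay) * update + decay * old_ema, element-wise.
        new_ema = {n: [(1 - decay) * g + decay * e
                       for g, e in zip(v, state["ema"][n])]
                   for n, v in updates.items()}
        count = state["count"] + 1
        out = new_ema
        if debias:
            # Correct the returned value for the zero initialization.
            c = 1 - decay ** count
            out = {n: [e / c for e in v] for n, v in new_ema.items()}
        return out, {"count": count, "ema": new_ema}

    return init, update
```

Fed a constant update, the debiased output equals that constant from the first step onward, while the stored average approaches it gradually.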

EmptyState()

An empty state for the simplest stateless transformations.

global_norm(updates)

Compute the global norm across a nested structure of tensors.

GradientTransformation(init, update)

A pair of pure functions implementing a gradient transformation.
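
To illustrate the init/update contract, here is a minimal sketch that defines its own GradientTransformation named tuple (a local stand-in, not the Optax class) and builds a stateless scale transformation with it. Both functions are pure: update returns the new updates and the new state instead of mutating anything. Updates are assumed to be dicts of float lists.

```python
from typing import Any, Callable, NamedTuple


class GradientTransformation(NamedTuple):
    # init(params) -> state
    init: Callable[[Any], Any]
    # update(updates, state, params=None) -> (new_updates, new_state)
    update: Callable[..., Any]


def scale(step_size):
    def init_fn(params):
        return ()  # Stateless: an empty tuple stands in for an empty state.

    def update_fn(updates, state, params=None):
        return {n: [step_size * g for g in v] for n, v in updates.items()}, state

    return GradientTransformation(init_fn, update_fn)


tx = scale(-0.1)
state = tx.init({"w": [1.0, 2.0]})
updates, state = tx.update({"w": [0.5, -1.0]}, state)
```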

GradientTransformationExtraArgs(init, update)

A specialization of GradientTransformation that supports extra args.

identity()

Stateless identity transformation that leaves input gradients untouched.

init_empty_state(params)

Init function for a GradientTransformation with empty state.

keep_params_nonnegative()

Modifies the updates to keep parameters non-negative, i.e. >= 0.
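
A sketch of the non-negativity rule, assuming params and updates are dicts of float lists: whenever p + u would be negative, the update is replaced by -p so that the parameter lands exactly on zero.

```python
def keep_params_nonnegative(updates, params):
    # If applying u would push p below zero, use -p instead (so p + u == 0).
    return {n: [-p if p + u < 0 else u for p, u in zip(params[n], v)]
            for n, v in updates.items()}
```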

measure_with_ema(measure, decay[, debias, ...])

Take a measurement and record it with exponential moving average.

monitor(measures)

Monitors stateful measurements of updates in a chain.

MonitorState(measurements, measure_states)

NonNegativeParamsState

normalize_by_update_norm([scale_factor, eps])

Scale by the inverse of the update norm.

OptState

Params

per_example_global_norm_clip(grads, l2_norm_clip)

Applies gradient clipping per-example using their global norm.

per_example_layer_norm_clip(grads, ...[, ...])

Applies gradient clipping per-example using per-layer norms.

scale(step_size)

Scale updates by some fixed scalar step_size.

ScaleState

scale_by_adadelta([rho, eps])

Rescale updates according to the Adadelta algorithm.

ScaleByAdaDeltaState(e_g, e_x)

State for rescaling by the Adadelta algorithm.

scale_by_adan([b1, b2, b3, eps, eps_root])

Rescale updates according to the Adan algorithm.

ScaleByAdanState(m, v, n, g, t)

scale_by_adam([b1, b2, eps, eps_root, ...])

Rescale updates according to the Adam algorithm.

scale_by_adamax([b1, b2, eps])

Rescale updates according to the Adamax algorithm.

ScaleByAdamState(count, mu, nu)

State for the Adam algorithm.
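
The Adam rescaling can be sketched in pure Python as below, assuming dict-of-list updates and the defaults b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0. Both moment estimates are bias-corrected before the division, so the first update has magnitude close to 1 per element. The sketch returns the rescaled direction only; no learning rate is applied here.

```python
import math


def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    def init(params):
        return {"count": 0,
                "mu": {n: [0.0] * len(v) for n, v in params.items()},
                "nu": {n: [0.0] * len(v) for n, v in params.items()}}

    def update(updates, state):
        # First and second moment EMAs.
        mu = {n: [b1 * m + (1 - b1) * g for m, g in zip(state["mu"][n], v)]
              for n, v in updates.items()}
        nu = {n: [b2 * s + (1 - b2) * g * g for s, g in zip(state["nu"][n], v)]
              for n, v in updates.items()}
        count = state["count"] + 1
        # Bias corrections for the zero-initialized moments.
        c1, c2 = 1 - b1 ** count, 1 - b2 ** count
        out = {n: [(m / c1) / (math.sqrt(s / c2 + eps_root) + eps)
                   for m, s in zip(mu[n], nu[n])]
               for n in updates}
        return out, {"count": count, "mu": mu, "nu": nu}

    return init, update
```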

scale_by_amsgrad([b1, b2, eps, eps_root, ...])

Rescale updates according to the AMSGrad algorithm.

ScaleByAmsgradState(count, mu, nu, nu_max)

State for the AMSGrad algorithm.

scale_by_backtracking_linesearch(...[, ...])

Backtracking line-search ensuring sufficient decrease (Armijo criterion).

ScaleByBacktrackingLinesearchState(...)

State for optax.scale_by_backtracking_linesearch().

scale_by_belief([b1, b2, eps, eps_root, ...])

Rescale updates according to the AdaBelief algorithm.

ScaleByBeliefState(count, mu, nu)

State for rescaling by the AdaBelief algorithm.

scale_by_factored_rms(factored, decay_rate, ...)

Scaling by a factored estimate of the gradient rms (as in Adafactor).

FactoredState(count, v_row, v_col, v)

Overall state of the gradient transformation.

scale_by_lbfgs([memory_size, scale_init_precond])

Scales updates by L-BFGS.

ScaleByLBFGSState(count, params, updates, ...)

State for the L-BFGS solver.

scale_by_learning_rate([learning_rate, ...])

Scale by the (negative) learning rate (either as scalar or as schedule).

scale_by_lion([b1, b2, mu_dtype, mode, ...])

Rescale updates according to the Lion algorithm.

ScaleByLionState(count, mu)

State for the Lion algorithm.
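
A pure-Python sketch of the Lion rescaling, assuming dict-of-list updates and defaults b1=0.9, b2=0.99 (taken here as assumptions). The output is the sign of an interpolation between the gradient and the momentum, so every element is -1, 0, or +1; the momentum itself is updated with the slower rate b2.

```python
def scale_by_lion(b1=0.9, b2=0.99):
    def sign(x):
        return float((x > 0) - (x < 0))

    def init(params):
        return {"count": 0, "mu": {n: [0.0] * len(v) for n, v in params.items()}}

    def update(updates, state):
        # Direction: sign of (1 - b1) * g + b1 * mu.
        out = {n: [sign((1 - b1) * g + b1 * m) for g, m in zip(v, state["mu"][n])]
               for n, v in updates.items()}
        # Momentum: EMA of gradients with rate b2.
        mu = {n: [b2 * m + (1 - b2) * g for g, m in zip(v, state["mu"][n])]
              for n, v in updates.items()}
        return out, {"count": state["count"] + 1, "mu": mu}

    return init, update
```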

scale_by_novograd([b1, b2, eps, eps_root, ...])

Computes NovoGrad updates.

ScaleByNovogradState(count, mu, nu)

State for the NovoGrad algorithm.

scale_by_optimistic_gradient([alpha, beta])

Compute generalized optimistic gradients.

scale_by_param_block_norm([min_scale])

Scale updates for each param block by the norm of that block's parameters.

scale_by_param_block_rms([min_scale])

Scale updates by rms of the gradient for each param vector or matrix.

scale_by_polyak([f_min, max_learning_rate, ...])

Scales the update by Polyak's step-size.
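
A sketch of the Polyak step size, assuming dict-of-list gradients and that the current objective value is passed in as value. The step is (value - f_min) / (||g||^2 + eps), capped at max_learning_rate. When the denominator is zero, this sketch returns zero updates, an assumption made to avoid a Python ZeroDivisionError.

```python
def scale_by_polyak(updates, value, f_min=0.0, max_learning_rate=1.0, eps=0.0):
    # Squared global norm of the gradient.
    grad_sq_norm = sum(g * g for v in updates.values() for g in v)
    denom = grad_sq_norm + eps
    if denom == 0:
        # Assumption for this sketch: a zero gradient yields a zero update.
        return {n: [0.0] * len(v) for n, v in updates.items()}
    # Polyak step size, capped at max_learning_rate.
    step = min((value - f_min) / denom, max_learning_rate)
    return {n: [step * g for g in v] for n, v in updates.items()}
```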

scale_by_radam([b1, b2, eps, eps_root, ...])

Rescale updates according to the Rectified Adam algorithm.

scale_by_rms([decay, eps, initial_scale, ...])

Rescale updates by the root of the exponential moving average of the squared updates.

ScaleByRmsState(nu)

State for exponential root mean-squared (RMS)-normalized updates.
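
A sketch of RMS scaling, assuming dict-of-list updates, the defaults decay=0.9, eps=1e-8, initial_scale=0.0, and eps added inside the square root. This sketch applies no bias correction, so early steps are amplified when initial_scale is 0.

```python
import math


def scale_by_rms(decay=0.9, eps=1e-8, initial_scale=0.0):
    def init(params):
        return {"nu": {n: [initial_scale] * len(v) for n, v in params.items()}}

    def update(updates, state):
        # EMA of squared updates.
        nu = {n: [decay * s + (1 - decay) * g * g for g, s in zip(v, state["nu"][n])]
              for n, v in updates.items()}
        # Divide by the root of that average.
        out = {n: [g / math.sqrt(s + eps) for g, s in zip(v, nu[n])]
               for n, v in updates.items()}
        return out, {"nu": nu}

    return init, update
```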

scale_by_rprop(learning_rate[, eta_minus, ...])

Scale with the Rprop optimizer.

ScaleByRpropState(step_sizes, prev_updates)

scale_by_rss([initial_accumulator_value, eps])

Rescale updates by the root of the sum of all squared gradients to date.

ScaleByRssState(sum_of_squares)

State holding the sum of gradient squares to date.
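
This Adagrad-style accumulator can be sketched in pure Python as follows, assuming dict-of-list updates and the defaults initial_accumulator_value=0.1 and eps=1e-7. The sum is never decayed, so the effective step size shrinks as squared gradients accumulate.

```python
import math


def scale_by_rss(initial_accumulator_value=0.1, eps=1e-7):
    def init(params):
        return {"sum_of_squares": {n: [initial_accumulator_value] * len(v)
                                   for n, v in params.items()}}

    def update(updates, state):
        # Running (undecayed) sum of squared gradients.
        sos = {n: [s + g * g for g, s in zip(v, state["sum_of_squares"][n])]
               for n, v in updates.items()}
        # Scale by 1 / sqrt(sum + eps), using 0 where the sum is still zero.
        out = {n: [g / math.sqrt(s + eps) if s > 0 else 0.0
                   for g, s in zip(v, sos[n])]
               for n, v in updates.items()}
        return out, {"sum_of_squares": sos}

    return init, update
```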

scale_by_schedule(step_size_fn)

Scale updates using a custom schedule for the step_size.

ScaleByScheduleState(count)

Maintains count for scale scheduling.

scale_by_sign()

Compute the signs of the gradient elements.

scale_by_sm3([b1, b2, eps])

Scale updates by the SM3 algorithm.

ScaleBySM3State(mu, nu)

State for the SM3 algorithm.

scale_by_stddev([decay, eps, initial_scale, ...])

Rescale updates by the root of the centered exponential moving average of squares.

ScaleByRStdDevState(mu, nu)

State for centered exponential moving average of squares of updates.

scale_by_trust_ratio([min_norm, ...])

Scale updates by trust ratio.

ScaleByTrustRatioState

scale_by_yogi([b1, b2, eps, eps_root, ...])

Rescale updates according to the Yogi algorithm.

scale_by_zoom_linesearch(max_linesearch_steps)

Linesearch ensuring sufficient decrease and small curvature.

ScaleByZoomLinesearchState(learning_rate, ...)

State for scale_by_zoom_linesearch.

set_to_zero()

Stateless transformation that maps input gradients to zero.

snapshot(measure_name, measure)

Takes a snapshot of updates and stores it in the state.

SnapshotState(measurement)

stateless(f)

Creates a stateless transformation from an update-like function.

stateless_with_tree_map(f)

Creates a stateless transformation from an update-like function for arrays.

trace(decay[, nesterov, accumulator_dtype])

Compute a trace of past updates.

TraceState(trace)

Holds an aggregation of past updates.
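
A sketch of trace as heavy-ball momentum, assuming dict-of-list updates: the new trace is g + decay * trace, and with nesterov=True the same rule is applied once more to the new trace to form the output.

```python
def trace(decay, nesterov=False):
    def init(params):
        return {"trace": {n: [0.0] * len(v) for n, v in params.items()}}

    def update(updates, state):
        def step(g, t):
            return g + decay * t

        new_trace = {n: [step(g, t) for g, t in zip(v, state["trace"][n])]
                     for n, v in updates.items()}
        if nesterov:
            # Look ahead by applying the momentum rule to the new trace.
            out = {n: [step(g, t) for g, t in zip(v, new_trace[n])]
                   for n, v in updates.items()}
        else:
            out = new_trace
        return out, {"trace": new_trace}

    return init, update
```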

TransformInitFn(*args, **kwargs)

A callable type for the init step of a GradientTransformation.

TransformUpdateFn(*args, **kwargs)

A callable type for the update step of a GradientTransformation.

TransformUpdateExtraArgsFn(*args, **kwargs)

An update function accepting additional keyword arguments.

update_infinity_moment(updates, moments, ...)

Compute the exponential moving average of the infinity norm.

update_moment(updates, moments, decay, order)

Compute the exponential moving average of the order-th moment.
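
The moment update is a single element-wise formula; here it is in pure Python over a dict of float lists (an illustrative stand-in for pytrees).

```python
def update_moment(updates, moments, decay, order):
    # (1 - decay) * g**order + decay * previous_moment, element-wise.
    return {n: [(1 - decay) * g ** order + decay * m for g, m in zip(v, moments[n])]
            for n, v in updates.items()}
```

With order=1 this is the first-moment EMA used for momentum; with order=2 it tracks squared gradients.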

update_moment_per_elem_norm(updates, ...)

Compute the EMA of the order-th moment of the element-wise norm.

Updates

with_extra_args_support(tx)

Wraps a gradient transformation so that it ignores extra args.

zero_nans()

A transformation which replaces NaNs with 0.

ZeroNansState(found_nan)

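Contains found_nan, a tree with the same structure as the parameters; each leaf is a boolean that is True if a NaN was detected in the corresponding array at the last update.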
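
A stateless sketch of NaN replacement over a dict of float lists. It also returns one boolean per leaf, mirroring the found_nan tree held by ZeroNansState.

```python
import math


def zero_nans(updates):
    # Replace every NaN element with 0.0.
    cleaned = {n: [0.0 if math.isnan(x) else x for x in v] for n, v in updates.items()}
    # One boolean per leaf: True if that leaf contained any NaN.
    found_nan = {n: any(math.isnan(x) for x in v) for n, v in updates.items()}
    return cleaned, found_nan
```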

ZoomLinesearchInfo(num_linesearch_steps, ...)

Information about the zoom linesearch step, exposed for debugging.

masked(inner, mask, *[, ...])

Mask updates so only some are transformed, the rest are passed through.
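
A stateless sketch of masking, assuming updates and mask are dicts with the same keys and inner_update is any function over such dicts (both hypothetical simplifications; the real transformation also manages the inner state). Only the leaves with mask=True are passed to the inner function; the rest pass through unchanged.

```python
def masked(inner_update, mask):
    # mask: dict of booleans keyed like the updates.
    def update(updates):
        selected = {n: v for n, v in updates.items() if mask[n]}
        transformed = inner_update(selected)
        # Masked-out leaves are returned exactly as they came in.
        return {n: transformed[n] if mask[n] else v for n, v in updates.items()}

    return update
```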

freeze(mask)

Create a transformation that zeros out gradient updates for mask=True.

selective_transform(optimizer, *, freeze_mask)

Partition updates so that only un-frozen parameters are optimized.