Transformations

Clips updates to be at most

Add parameter scaled by weight_decay.
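In other words, the transformation adds `weight_decay * params` to each incoming update, leaf by leaf. The following is a minimal pure-Python sketch. It assumes that parameters and updates are plain dicts mapping a name to a list of floats, and `add_decayed_weights` here is an illustrative function, not a library call:

```python
# Sketch only: "trees" are dicts mapping a parameter name to a list of floats.
# `add_decayed_weights` is an illustrative name, not a library import.

def add_decayed_weights(updates, params, weight_decay):
    """Return updates + weight_decay * params, computed leaf by leaf."""
    return {
        name: [u + weight_decay * p for u, p in zip(updates[name], params[name])]
        for name in updates
    }


grads = {"w": [0.5, -1.0], "b": [0.1]}
params = {"w": [2.0, 4.0], "b": [-1.0]}
print(add_decayed_weights(grads, params, weight_decay=0.1))
```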
|
Add gradient noise.

State for adding gradient noise.

Accumulate gradients and apply them every k steps.
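Gradient accumulation keeps a running sum and a mini-step counter in its state. The sketch below assumes two things: accumulated gradients are averaged over `k` calls, and intermediate calls emit zero updates. `accumulate_init` and `accumulate_update` are hypothetical names. A real wrapper would also pass the averaged gradient to an inner optimizer, which this sketch omits:

```python
# Hypothetical sketch of gradient accumulation over k steps.
# State is (mini_step, accumulated_gradients).

def accumulate_init(params):
    return 0, [0.0 for _ in params]


def accumulate_update(grads, state, k):
    mini_step, acc = state
    acc = [a + g for a, g in zip(acc, grads)]
    mini_step += 1
    if mini_step == k:
        # k-th call: emit the mean gradient and reset the accumulator.
        mean = [a / k for a in acc]
        return mean, accumulate_init(acc)
    # Intermediate call: emit zeros so parameters do not move yet.
    return [0.0 for _ in grads], (mini_step, acc)


state = accumulate_init([1.0, 2.0])
for g in ([1.0, 2.0], [3.0, 4.0]):
    updates, state = accumulate_update(g, state, k=2)
print(updates)  # [2.0, 3.0]
```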
|
Contains a counter and a gradient accumulator.

Performs bias correction.

Calls the inner update function only at certain steps.

Calls the inner update function only at certain steps.

Maintains inner transform state and adds a step counter.

Centralizes gradients by subtracting their mean along the leading dimension.

Clips updates element-wise, to be in

Clips updates to a max RMS for the gradient of each param vector or matrix.

Clips updates using their global norm.
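Global-norm clipping computes a single L2 norm over every leaf of the update tree. When that norm exceeds a threshold, all leaves are rescaled by the same factor. Here is a pure-Python sketch, assuming a dict-of-lists tree and a hypothetical `max_norm` parameter:

```python
import math


def global_norm(tree):
    """L2 norm over every element of every leaf in a dict of float lists."""
    return math.sqrt(sum(x * x for leaf in tree.values() for x in leaf))


def clip_by_global_norm(updates, max_norm):
    norm = global_norm(updates)
    if norm <= max_norm:
        return updates  # already within the bound: leave untouched
    # One shared factor for all leaves, so the global norm becomes max_norm.
    factor = max_norm / norm
    return {name: [x * factor for x in leaf] for name, leaf in updates.items()}


updates = {"w": [3.0], "b": [4.0]}  # global norm is 5
clipped = clip_by_global_norm(updates, max_norm=1.0)
print(clipped, global_norm(clipped))
```

Because every leaf shares one scale factor, clipping shortens the update but keeps its direction.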
|
Compute an exponential moving average of past updates.
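An exponential moving average keeps `mu <- decay * mu + (1 - decay) * u` in its state. Because `mu` starts at zero, early estimates are biased toward zero. Dividing by `1 - decay**count` undoes that bias. A sketch with illustrative names, assuming a flat list of floats per update and an optional `debias` flag:

```python
def ema_init(size):
    # Zero-initialised average plus a step counter for bias correction.
    return {"count": 0, "mu": [0.0] * size}


def ema_update(updates, state, decay, debias=True):
    count = state["count"] + 1
    mu = [decay * m + (1.0 - decay) * u for m, u in zip(state["mu"], updates)]
    new_state = {"count": count, "mu": mu}
    if debias:
        # Bias correction: rescale to undo the pull toward the zero init.
        return [m / (1.0 - decay ** count) for m in mu], new_state
    return mu, new_state


state = ema_init(1)
out, state = ema_update([2.0], state, decay=0.9)
print(out)  # close to [2.0]: the debiased average of a single value
```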
|
Holds an exponential moving average of past updates.

An empty state for the simplest stateless transformations.

Compute the global norm across a nested structure of tensors.

A pair of pure functions implementing a gradient transformation.
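The init/update pair can be illustrated with a simplified stand-in. `init` builds a state from the parameters. `update` maps `(updates, state, params)` to `(new_updates, new_state)` without side effects. The `GradientTransformation`, `scale` and `chain` below are hypothetical pure-Python re-creations written for this sketch, not imports from a library:

```python
from typing import Any, Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Simplified stand-in: a pair of pure functions."""
    init: Callable[[Any], Any]    # params -> initial state
    update: Callable[..., tuple]  # (updates, state, params) -> (updates, state)


def scale(step_size):
    """Stateless transformation multiplying every update by step_size."""
    def init_fn(params):
        return ()  # empty state

    def update_fn(updates, state, params=None):
        return [step_size * u for u in updates], state

    return GradientTransformation(init_fn, update_fn)


def chain(*transforms):
    """Apply transformations in sequence; the state is a tuple of inner states."""
    def init_fn(params):
        return tuple(t.init(params) for t in transforms)

    def update_fn(updates, state, params=None):
        new_states = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s, params)
            new_states.append(s)
        return updates, tuple(new_states)

    return GradientTransformation(init_fn, update_fn)


tx = chain(scale(2.0), scale(-0.5))
params = [1.0, 2.0]
state = tx.init(params)
updates, state = tx.update([0.5, 1.0], state, params)
print(updates)  # [-0.5, -1.0]
```

Because both functions are pure, the caller owns the state and threads it through each step explicitly.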
|
A specialization of GradientTransformation that supports extra args.

Stateless identity transformation that leaves input gradients untouched.

Init function for a

Modifies the updates to keep parameters non-negative, i.e. >= 0.
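One way to keep parameters non-negative is to check `params + updates`. Wherever that sum would be negative, the update is replaced with `-params`, so the parameter lands exactly on zero. A sketch under that assumption, using flat lists of floats and assuming the parameters are already non-negative:

```python
def keep_params_nonnegative(updates, params):
    # If p + u would go below zero, use -p instead so that p + u == 0.
    return [-p if p + u < 0.0 else u for u, p in zip(updates, params)]


params = [1.0, 0.5, 2.0]
updates = [-0.25, -1.0, 0.5]
safe = keep_params_nonnegative(updates, params)
new_params = [p + u for p, u in zip(params, safe)]
print(new_params)  # [0.75, 0.0, 2.5]
```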
|
|
Take a measurement and record it with an exponential moving average.

Monitors stateful measurements of updates in a chain.

Scale by the inverse of the update norm.

Applies gradient clipping per-example using their global norm.

Applies gradient clipping per-example using per-layer norms.

Scale updates by some fixed scalar step_size.

Rescale updates according to the Adadelta algorithm.

State for the rescaling by the Adadelta algorithm.

Rescale updates according to the Adan algorithm.

Rescale updates according to the Adam algorithm.
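The Adam rescaling keeps exponential moving averages of the gradient and of its square, bias-corrects both, and divides the first by the square root of the second. The sketch uses the commonly cited defaults `b1=0.9`, `b2=0.999`, `eps=1e-8` and flat lists of floats, which are assumptions here. It returns only the rescaled direction. A separate step would multiply it by the negative learning rate:

```python
import math


def adam_init(n):
    return {"count": 0, "mu": [0.0] * n, "nu": [0.0] * n}


def adam_update(grads, state, b1=0.9, b2=0.999, eps=1e-8):
    count = state["count"] + 1
    # First and second moment EMAs.
    mu = [b1 * m + (1 - b1) * g for m, g in zip(state["mu"], grads)]
    nu = [b2 * v + (1 - b2) * g * g for v, g in zip(state["nu"], grads)]
    # Bias correction for the zero initialisation.
    mu_hat = [m / (1 - b1 ** count) for m in mu]
    nu_hat = [v / (1 - b2 ** count) for v in nu]
    updates = [m / (math.sqrt(v) + eps) for m, v in zip(mu_hat, nu_hat)]
    return updates, {"count": count, "mu": mu, "nu": nu}


updates, state = adam_update([3.0, -2.0], adam_init(2))
```

On the first step, bias correction recovers `g` and `g**2` exactly, so each update component is close to the sign of its gradient.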
|
Rescale updates according to the Adamax algorithm.

State for the Adam algorithm.

Rescale updates according to the AMSGrad algorithm.

State for the AMSGrad algorithm.

Backtracking line-search ensuring sufficient decrease (Armijo criterion).

State for
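A backtracking line search starts from a trial step and shrinks it until the Armijo sufficient-decrease condition `f(x + t*d) <= f(x) + c * t * <grad f(x), d>` holds. The function name, the halving factor and `c=1e-4` below are assumptions for this pure-Python sketch:

```python
def backtracking_armijo(f, grad_f, x, direction, step=1.0, c=1e-4,
                        shrink=0.5, max_iter=30):
    fx = f(x)
    # Directional derivative <grad f(x), d>; negative for a descent direction.
    slope = sum(g * d for g, d in zip(grad_f(x), direction))
    for _ in range(max_iter):
        candidate = [xi + step * di for xi, di in zip(x, direction)]
        if f(candidate) <= fx + c * step * slope:
            return step  # sufficient decrease reached
        step *= shrink
    return step  # give up after max_iter shrinks


f = lambda x: x[0] ** 2
grad_f = lambda x: [2.0 * x[0]]
x = [1.0]
step = backtracking_armijo(f, grad_f, x, direction=[-2.0])
```

For `f(x) = x**2` at `x = 1` with `d = -2`, the full step lands on `x = -1` and gives no decrease. The halved step `t = 0.5` reaches the minimum at zero.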
|
|
Rescale updates according to the AdaBelief algorithm. |
|
State for the rescaling by AdaBelief algorithm. |
|
Scaling by a factored estimate of the gradient rms (as in Adafactor). |
|
Overall state of the gradient transformation. |
|
Scales updates by L-BFGS. |
|
State for LBFGS solver. |
|
Scale by the (negative) learning rate (either as scalar or as schedule). |
|
Rescale updates according to the Lion algorithm. |
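Lion uses the sign of an interpolation between the momentum and the current gradient as its update. It then refreshes the momentum with a second coefficient. The coefficients below (`b1=0.9`, `b2=0.99`) and the function names are illustrative assumptions:

```python
def sign(x):
    # Returns 1, -1 or 0.
    return (x > 0) - (x < 0)


def lion_update(grads, mu, b1=0.9, b2=0.99):
    # Update: sign of an interpolation between momentum and gradient (uses b1).
    updates = [float(sign(b1 * m + (1 - b1) * g)) for m, g in zip(mu, grads)]
    # Momentum refresh uses a different coefficient (b2).
    new_mu = [b2 * m + (1 - b2) * g for m, g in zip(mu, grads)]
    return updates, new_mu


updates, mu = lion_update([0.5, -2.0, 0.0], [0.0, 0.0, 0.0])
print(updates)  # [1.0, -1.0, 0.0]
```

Because only the sign is kept, every nonzero update component has unit magnitude. The learning rate alone then sets the step length.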
|
State for the Lion algorithm.

Computes NovoGrad updates.

State for NovoGrad.

Compute generalized optimistic gradients.

Scale updates for each param block by the norm of that block's parameters.

Scale updates by the RMS of the gradient for each param vector or matrix.

Scales the update by Polyak's step size.
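Polyak's step size sets `step = (f(x) - f_min) / ||g||**2`, where `f_min` is a known or estimated optimal value. The sketch below makes three assumptions: the caller supplies `f_min`, a small `eps` guards against division by zero, and a hypothetical `max_step` caps the step:

```python
def polyak_sgd_update(grads, f_value, f_min=0.0, max_step=float("inf"), eps=1e-12):
    sq_norm = sum(g * g for g in grads)
    # Polyak step (f(x) - f_min) / ||g||^2, capped at max_step.
    step = min((f_value - f_min) / (sq_norm + eps), max_step)
    return [-step * g for g in grads]  # updates to add to the parameters


# f(x) = x**2 at x = 3: f = 9, g = 6.
updates = polyak_sgd_update([6.0], f_value=9.0)
```

Here `step = 9 / 36 = 0.25`, so the update is about `-1.5` and moves `x` from 3 to 1.5.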
|
Rescale updates according to the Rectified Adam algorithm.

Rescale updates by the root of the exponential moving average of the squared updates.

State for exponential root mean-squared (RMS)-normalized updates.

Scale with the Rprop optimizer.

Rescale updates by the root of the sum of all squared gradients to date.
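Scaling by the root of the sum of squares is the AdaGrad-style rule: each coordinate's update is divided by the square root of the sum of all its squared gradients so far. The sketch assumes a zero-initialized accumulator and places `eps` inside the square root. Both are assumptions of the sketch, not details taken from the source:

```python
import math


def rss_init(n):
    # Hypothetical choice: start the running sum of squares at zero.
    return [0.0] * n


def rss_update(grads, sum_sq, eps=1e-7):
    sum_sq = [s + g * g for s, g in zip(sum_sq, grads)]
    updates = [g / math.sqrt(s + eps) for g, s in zip(grads, sum_sq)]
    return updates, sum_sq


acc = rss_init(1)
u1, acc = rss_update([3.0], acc)  # sum = 9,  update close to 3 / 3 = 1
u2, acc = rss_update([4.0], acc)  # sum = 25, update close to 4 / 5 = 0.8
```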
|
State holding the sum of gradient squares to date.

Scale updates using a custom schedule for the step_size.

Maintains a count for scale scheduling.

Compute the signs of the gradient elements.

Scale updates by SM3.

State for the SM3 algorithm.

Rescale updates by the root of the centered exponential moving average of squares.

State for centered exponential moving average of squares of updates.

Scale updates by trust ratio.
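Trust-ratio scaling, as used by layer-wise methods such as LARS and LAMB, multiplies each block's update by `||params|| / ||updates||`. The sketch assumes a per-layer dict of float lists. When either norm is zero it falls back to a ratio of 1, which is also an assumption of this sketch:

```python
import math


def l2(xs):
    return math.sqrt(sum(x * x for x in xs))


def scale_by_trust_ratio(updates, params):
    out = {}
    for name, u in updates.items():
        p_norm, u_norm = l2(params[name]), l2(u)
        # Fallback of 1.0 avoids 0/0 and avoids zeroing updates of zero params.
        ratio = p_norm / u_norm if p_norm > 0 and u_norm > 0 else 1.0
        out[name] = [ratio * x for x in u]
    return out


scaled = scale_by_trust_ratio({"w": [3.0, 4.0]}, {"w": [0.0, 10.0]})
print(scaled)  # {'w': [6.0, 8.0]}
```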
|
Rescale updates according to the Yogi algorithm.

Line search ensuring sufficient decrease and small curvature.

State for scale_by_zoom_linesearch.

Stateless transformation that maps input gradients to zero.

Takes a snapshot of updates and stores it in the state.

Creates a stateless transformation from an update-like function.

Creates a stateless transformation from an update-like function for arrays.

Compute a trace of past updates.

Holds an aggregation of past updates.

A callable type for the init step of a GradientTransformation.

A callable type for the update step of a GradientTransformation.

An update function accepting additional keyword arguments.

Compute the exponential moving average of the infinity norm.

Compute the exponential moving average of the order-th moment.

Compute the EMA of the order-th moment of the element-wise norm.

Wraps a gradient transformation so that it ignores extra args.

A transformation which replaces NaNs with 0.

Contains a tree.

Information about the zoom line search step, exposed for debugging.

Mask updates so that only some are transformed and the rest are passed through.

Create a transformation that zeros out gradient updates for mask=True.

Partition updates so that only un-frozen parameters are optimized.