optax.scale_by_rss
- optax.scale_by_rss(initial_accumulator_value: jax.typing.ArrayLike = 0.1, eps: jax.typing.ArrayLike = 1e-07) → optax.GradientTransformation
Rescale updates by the root of the sum of all squared gradients to date.
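As a rough illustration, the rule can be sketched on a single array with jax.numpy: the accumulator starts at s_0 = initial_accumulator_value, each step adds the squared gradient (s_t = s_{t-1} + g_t^2), and the outgoing update is u_t = g_t / sqrt(s_t + eps). The helper names rss_init and rss_update are hypothetical. Placing eps inside the square root, and returning a zero update where the running sum is zero, reflects our reading of the optax source rather than a documented guarantee.

```python
import jax.numpy as jnp


def rss_init(params, initial_accumulator_value=0.1):
    # Hypothetical helper: one accumulator entry per parameter entry,
    # seeded with initial_accumulator_value.
    return jnp.full_like(params, initial_accumulator_value)


def rss_update(grad, sum_of_squares, eps=1e-7):
    # Hypothetical helper: add this step's squared gradient to the running sum.
    sum_of_squares = sum_of_squares + jnp.square(grad)
    # Scale by 1 / sqrt(sum + eps); entries whose running sum is zero get a
    # zero update (assumption based on our reading of the optax source).
    scale = jnp.where(sum_of_squares > 0, 1.0 / jnp.sqrt(sum_of_squares + eps), 0.0)
    return grad * scale, sum_of_squares


# Two identical gradient steps on a 2-element parameter vector.
grads = [jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0])]
state = rss_init(jnp.zeros(2))
for g in grads:
    update, state = rss_update(g, state)
```

After two steps the accumulator holds 0.1 + 1 + 1 = 2.1 and 0.1 + 4 + 4 = 8.1, so later gradients are damped more strongly than early ones.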
See optax.adagrad() for more details.
- Parameters:
initial_accumulator_value – Starting value for the accumulators; must be >= 0.
eps – A small floating-point value to avoid a zero denominator.
- Returns:
An optax.GradientTransformation object.