optax.scale_by_rss

optax.scale_by_rss(initial_accumulator_value: jax.typing.ArrayLike = 0.1, eps: jax.typing.ArrayLike = 1e-07) → optax.GradientTransformation

Rescale updates element-wise by the inverse square root of the sum of all squared gradients seen so far.

See optax.adagrad() for more details.
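The rule can be sketched for a single array as follows. This is a minimal, hypothetical re-implementation in NumPy (the name `rss_scale_sketch` is illustrative and not part of optax). It assumes the accumulator starts at `initial_accumulator_value`, that `eps` is added inside the square root, and that entries whose accumulator is still zero receive a zero update; consult the optax source for the exact behavior.

```python
import numpy as np


def rss_scale_sketch(grads_seq, initial_accumulator_value=0.1, eps=1e-7):
    """Hypothetical sketch of the scale_by_rss rule for a sequence of gradients.

    Not optax's implementation; eps placement is an assumption.
    """
    # Running sum of squared gradients, initialized to the starting value.
    acc = np.full_like(grads_seq[0], initial_accumulator_value, dtype=float)
    scaled = []
    for g in grads_seq:
        # Accumulate the element-wise squared gradient.
        acc = acc + np.square(g)
        # Assumption: eps goes inside the square root, and entries with a
        # zero accumulator get a zero scale instead of a huge one.
        scale = np.where(acc > 0, 1.0 / np.sqrt(acc + eps), 0.0)
        scaled.append(scale * g)
    return scaled


# Two identical gradient steps: the second is scaled down more because the
# accumulator has grown.
outs = rss_scale_sketch([np.array([0.5, 0.0]), np.array([0.5, 0.0])])
```

Because the accumulator only grows, the effective step size for each entry shrinks monotonically over training.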

Parameters:
  • initial_accumulator_value – Starting value for the accumulators; must be >= 0.

  • eps – A small floating-point value used to avoid a zero denominator.

Returns:

An optax.GradientTransformation object.