optax.scale_by_learning_rate


optax.scale_by_learning_rate(learning_rate: base.ScalarOrSchedule | None = None, *, flip_sign: bool = True) → base.GradientTransformation

Scale updates by the (negative) learning rate, given either as a scalar or as a schedule.

Parameters:
  • learning_rate – Either a scalar or a schedule (i.e. a callable that maps an integer step count to a float). None means no scaling.

  • flip_sign – If True (the default), updates are scaled by the negative learning rate; if False, they are scaled by the learning rate itself.

Returns:

An optax.GradientTransformation that multiplies the incoming updates (e.g. gradients) by -learning_rate if flip_sign is True, or by learning_rate if flip_sign is False.