optax.scale_by_learning_rate
- optax.scale_by_learning_rate(learning_rate: base.ScalarOrSchedule | None = None, *, flip_sign: bool = True) → base.GradientTransformation
Scale updates by the (negative) learning rate, given either as a scalar or as a schedule.
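The following plain-Python sketch makes the scalar-versus-schedule distinction and the sign convention concrete. It is not optax's implementation. The names `ScaleState` and `make_scale_by_lr` are hypothetical, a Python list of floats stands in for a pytree of updates, and the `None` case is left out.

```python
from typing import Callable, List, NamedTuple, Tuple, Union

# Hypothetical stand-in for base.ScalarOrSchedule: a float, or a callable step -> float.
ScalarOrSchedule = Union[float, Callable[[int], float]]


class ScaleState(NamedTuple):
    count: int  # number of updates processed so far; used to evaluate schedules


def make_scale_by_lr(learning_rate: ScalarOrSchedule, flip_sign: bool = True):
    """Sketch of the scaling rule; returns hypothetical (init, update) functions."""
    sign = -1.0 if flip_sign else 1.0

    def init() -> ScaleState:
        return ScaleState(count=0)

    def update(updates: List[float], state: ScaleState) -> Tuple[List[float], ScaleState]:
        # A schedule is evaluated at the current step; a scalar is used as-is.
        lr = learning_rate(state.count) if callable(learning_rate) else learning_rate
        scaled = [sign * lr * u for u in updates]
        return scaled, ScaleState(count=state.count + 1)

    return init, update


# Scalar learning rate with flip_sign=True: each update becomes -0.1 * u.
init, update = make_scale_by_lr(0.1)
state = init()
out, state = update([1.0, -2.0], state)
print(out)  # [-0.1, 0.2]
```

The step counter in the state is what allows a schedule to return a different learning rate on each call to `update`.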
- Parameters:
learning_rate – Can be either a scalar or a schedule (i.e. a callable that maps an (int) step to a float). None means no scaling.
flip_sign – When set to True (the default), this corresponds to scaling by the negative learning rate.
- Returns:
An optax.GradientTransformation that multiplies the incoming updates (e.g. gradients) by -learning_rate (if flip_sign is True) or by learning_rate (if flip_sign is False).