optax.add_decayed_weights
- optax.add_decayed_weights(weight_decay: base.ScalarOrSchedule = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) → base.GradientTransformation
Add each parameter, scaled by weight_decay, to its corresponding update (updates + weight_decay * params).
- Parameters:
weight_decay – A scalar weight decay rate, or a schedule (base.ScalarOrSchedule) giving the rate as a function of the step count.
mask – A tree with the same structure as (or a prefix of) the params PyTree, or a callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
- Returns:
An optax.GradientTransformation object.