optax.add_decayed_weights

optax.add_decayed_weights(weight_decay: base.ScalarOrSchedule = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) -> base.GradientTransformation

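Adds the parameters, scaled by weight_decay, to the incoming updates: each update leaf u becomes u + weight_decay * p, where p is the corresponding parameter leaf. Because the transformation reads the parameters, params must be passed to its update function.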
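
The arithmetic (each update leaf u becomes u + weight_decay * p) can be checked with a minimal pure-Python sketch. It assumes the pytrees are flat dicts of floats, whereas optax itself maps over arbitrary JAX pytrees; the helper name add_decayed_weights_sketch is hypothetical and not part of optax. The values are chosen to be exact in binary floating point.

```python
from typing import Dict

Tree = Dict[str, float]  # assumption: a flat dict stands in for a JAX pytree


def add_decayed_weights_sketch(updates: Tree, params: Tree,
                               weight_decay: float = 0.0) -> Tree:
    """Hypothetical helper: add weight_decay * param to each update leaf."""
    # Leaves are matched by key, so both dicts must share the same structure.
    return {k: u + weight_decay * params[k] for k, u in updates.items()}


grads = {"w": 0.5, "b": -0.5}
params = {"w": 2.0, "b": 4.0}
print(add_decayed_weights_sketch(grads, params, weight_decay=0.25))
# {'w': 1.0, 'b': 0.5}
```

With the default weight_decay of 0.0 the sketch returns the updates unchanged, matching the default in the signature above.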

Parameters:
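  • weight_decay – A scalar weight decay rate, or a schedule (a callable mapping the step count to a rate). Defaults to 0.0, which leaves the updates unchanged.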

  • mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. If None (the default), the decay is applied to every leaf.
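
To illustrate the mask semantics, here is a pure-Python sketch, again assuming flat dicts of floats instead of general pytrees (so prefix masks are not modelled). The mask is either a dict of booleans with the same keys as the params, or a callable that builds such a dict from the params; leaves marked False pass through unchanged. The helper names and the "skip biases" rule are hypothetical examples, not optax API.

```python
from typing import Callable, Dict, Optional, Union

Tree = Dict[str, float]          # assumption: flat dict stands in for a pytree
MaskTree = Dict[str, bool]
Mask = Union[MaskTree, Callable[[Tree], MaskTree]]


def masked_decay_sketch(updates: Tree, params: Tree, weight_decay: float,
                        mask: Optional[Mask] = None) -> Tree:
    """Hypothetical helper: add weight_decay * param only where the mask is True."""
    if mask is None:
        # No mask: every leaf is decayed.
        mask_tree = {k: True for k in params}
    elif callable(mask):
        # A callable mask is evaluated on the params to build the boolean tree.
        mask_tree = mask(params)
    else:
        mask_tree = mask
    return {
        k: (u + weight_decay * params[k]) if mask_tree[k] else u
        for k, u in updates.items()
    }


def no_bias_decay(params: Tree) -> MaskTree:
    # Hypothetical rule: decay every leaf whose name does not contain "bias".
    return {k: "bias" not in k for k in params}


grads = {"kernel": 0.5, "bias": -0.5}
params = {"kernel": 2.0, "bias": 4.0}
print(masked_decay_sketch(grads, params, 0.25, mask=no_bias_decay))
# {'kernel': 1.0, 'bias': -0.5}
```

Passing a callable rather than a fixed tree lets the same mask rule be reused for models whose parameter structure is only known once the params exist.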

Returns:

An optax.GradientTransformation object.
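
An optax.GradientTransformation is a pair of functions, init(params) -> state and update(updates, state, params) -> (new_updates, new_state). The pure-Python sketch below mimics that shape for this transformation and chains it with a step that scales by -learning_rate, giving plain SGD with weight decay. Everything in it is hypothetical: the TransformSketch NamedTuple, the *_sketch helpers, the flat-dict pytrees, and the assumption that a schedule is evaluated at a step counter kept in the state.

```python
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Union

Tree = Dict[str, float]  # assumption: flat dict stands in for a pytree
ScalarOrSchedule = Union[float, Callable[[int], float]]


class TransformSketch(NamedTuple):
    """Hypothetical stand-in for optax.GradientTransformation: an (init, update) pair."""
    init: Callable[[Tree], Any]
    update: Callable[..., Tuple[Tree, Any]]


def add_decayed_weights_sketch(weight_decay: ScalarOrSchedule = 0.0) -> TransformSketch:
    def init(params: Tree) -> int:
        del params  # the state does not depend on the parameter values
        return 0  # assumed state: number of updates applied so far

    def update(updates: Tree, count: int,
               params: Optional[Tree] = None) -> Tuple[Tree, int]:
        if params is None:
            raise ValueError("params must be passed to compute weight decay")
        # A schedule is called with the step count; a plain scalar is used as-is.
        wd = weight_decay(count) if callable(weight_decay) else weight_decay
        new_updates = {k: u + wd * params[k] for k, u in updates.items()}
        return new_updates, count + 1

    return TransformSketch(init, update)


def scale_sketch(factor: float) -> TransformSketch:
    # Stateless step: multiply every update leaf by `factor`.
    def init(params: Tree) -> tuple:
        return ()

    def update(updates: Tree, state: tuple,
               params: Optional[Tree] = None) -> Tuple[Tree, tuple]:
        return {k: factor * u for k, u in updates.items()}, state

    return TransformSketch(init, update)


def chain_sketch(*txs: TransformSketch) -> TransformSketch:
    # Apply each transformation in order, keeping one state per transformation.
    def init(params: Tree) -> tuple:
        return tuple(tx.init(params) for tx in txs)

    def update(updates: Tree, states: tuple,
               params: Optional[Tree] = None) -> Tuple[Tree, tuple]:
        new_states = []
        for tx, state in zip(txs, states):
            updates, state = tx.update(updates, state, params)
            new_states.append(state)
        return updates, tuple(new_states)

    return TransformSketch(init, update)


# SGD with weight decay: add 0.25 * param, then scale by -learning_rate (0.5).
params = {"w": 2.0}
grads = {"w": 0.5}
tx = chain_sketch(add_decayed_weights_sketch(0.25), scale_sketch(-0.5))
state = tx.init(params)
updates, state = tx.update(grads, state, params)
params = {k: p + updates[k] for k, p in params.items()}
print(params)  # {'w': 1.5}
```

With the real library, a comparable pipeline would typically be written as optax.chain(optax.add_decayed_weights(0.25), optax.sgd(0.5)), with optax.apply_updates(params, updates) performing the final addition.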