optax.apply_every
- optax.apply_every(k: jax.typing.ArrayLike = 1) → optax.GradientTransformation
Accumulate gradients and apply them every k steps.
Note that if this transformation is part of a chain, the states of the other transformations in the chain are still updated at every step. In particular, using apply_every with a batch size of N/2 and k=2 is not necessarily equivalent to not using apply_every with a batch size of N. If this equivalence is important to you, consider using optax.MultiSteps instead.
- Parameters:
k – Emit non-zero gradients every k steps; otherwise, accumulate them.
- Returns:
An optax.GradientTransformation object.
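The role of k can be modeled without JAX. The function below, `apply_every_sketch`, is a hypothetical scalar model and not an optax API; it assumes, as in the optax source we are aware of, that the accumulated updates are summed, that the sum restarts at the beginning of each window of k steps, and that the sum is released on the k-th call of each window.

```python
def apply_every_sketch(updates, k=1):
    """Hypothetical scalar model of apply_every's accumulate-and-emit rule."""
    if k < 1:
        raise ValueError("k must be at least 1")
    acc = 0.0
    emitted = []
    for step, u in enumerate(updates):
        phase = step % k
        # Restart the running sum at the start of each window, else add to it.
        acc = u if phase == 0 else acc + u
        # Release the sum only on the last step of the window; emit zero otherwise.
        emitted.append(acc if phase == k - 1 else 0.0)
    return emitted


print(apply_every_sketch([1.0, 2.0, 3.0, 4.0, 5.0], k=2))  # → [0.0, 3.0, 0.0, 7.0, 0.0]
print(apply_every_sketch([1.0, 2.0, 3.0], k=1))            # → [1.0, 2.0, 3.0]
```

With the default k=1 every update passes through unchanged, so the transformation is effectively a no-op unless k is set above 1.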