optax.apply_every

optax.apply_every(k: jax.typing.ArrayLike = 1) → optax.GradientTransformation

Accumulate gradients and apply them every k steps.

Note that if this transformation is part of a chain, the states of the other transformations are still updated at every step, including the steps on which apply_every emits zero updates. In particular, using apply_every with a batch size of N/2 and k=2 is not necessarily equivalent to training without apply_every at a batch size of N. If this equivalence is important to you, consider using optax.MultiSteps instead.

Parameters:

k – Number of steps over which gradients are accumulated. The accumulated gradients are emitted every k steps; zero updates are emitted on the steps in between.

Returns:

An optax.GradientTransformation object.