optax.incremental_update


optax.incremental_update(new_tensors: optax.Params, old_tensors: optax.Params, step_size: jax.typing.ArrayLike) -> optax.Params

Incrementally update parameters via Polyak averaging.

Polyak averaging tracks an exponential moving average of a model's past parameters, for use at test/evaluation time.

Parameters:
  • new_tensors – the latest value of the tensors.

  • old_tensors – a moving average of the values of the tensors.

  • step_size – the step size used to update the Polyak average on each step, i.e. the weight given to new_tensors.

Returns:

an updated moving average, step_size * new_tensors + (1 - step_size) * old_tensors, of the params.
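The returned value is a per-leaf weighted combination of the two trees. The pure-Python sketch below reproduces that arithmetic on nested dicts of floats; polyak_update is a hypothetical helper written only for illustration, not optax's implementation, which operates on arbitrary JAX pytrees.

```python
def polyak_update(new, old, step_size):
    """Hypothetical sketch: step_size * new + (1 - step_size) * old,
    applied leaf by leaf over nested dicts with matching structure."""
    if isinstance(new, dict):
        # Recurse into sub-trees, pairing leaves by key.
        return {k: polyak_update(new[k], old[k], step_size) for k in new}
    return step_size * new + (1.0 - step_size) * old

new = {"w": 1.0, "b": {"bias": 4.0}}
old = {"w": 0.0, "b": {"bias": 2.0}}

avg = polyak_update(new, old, step_size=0.25)
# w:    0.25 * 1.0 + 0.75 * 0.0 = 0.25
# bias: 0.25 * 4.0 + 0.75 * 2.0 = 2.5
```

At the extremes, step_size=1.0 returns new_tensors unchanged and step_size=0.0 returns old_tensors unchanged.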

References

[Polyak et al., 1991](https://epubs.siam.org/doi/10.1137/0330046)