optax.incremental_update
- optax.incremental_update(new_tensors: optax.Params, old_tensors: optax.Params, step_size: jax.typing.ArrayLike) → optax.Params
Incrementally update parameters via Polyak averaging.
Polyak averaging tracks an average of a model's past parameters for use at test/evaluation time; this function maintains that average as an exponential moving average.
- Parameters:
new_tensors – the latest values of the tensors.
old_tensors – a moving average of the past values of the tensors.
step_size – the step size used to update the Polyak average on each step.
- Returns:
the updated moving average of the params, computed leaf-wise as step_size * new_tensors + (1 - step_size) * old_tensors.
References
[Polyak et al., 1991](https://epubs.siam.org/doi/10.1137/0330046)