Apply Updates#

  • apply_updates(params, updates) – Applies an update to the corresponding parameters.

  • incremental_update(new_tensors, old_tensors, ...) – Incrementally update parameters via Polyak averaging.

  • periodic_update(new_tensors, old_tensors, ...) – Periodically update all parameters with new values.

Apply updates#

optax.apply_updates(params, updates)#

Applies an update to the corresponding parameters.

This is a utility function that applies an update to a set of parameters and returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of `GradientTransformations`. This function is exposed for convenience, but it simply adds the updates to the parameters; you may also apply updates to parameters manually, using `tree_map` (e.g. if you want to manipulate updates in custom ways before applying them).

Parameters:
  • params (optax.Params) – a tree of parameters.

  • updates (optax.Updates) – a tree of updates; the tree structure and the shape of the leaf nodes must match that of params.

Return type:

optax.Params

Returns:

Updated parameters, with same structure, shape and type as params.
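
For instance, a minimal training step built around a gradient transformation might look like the following sketch (the loss function and parameter tree are illustrative):

```python
import jax
import jax.numpy as jnp
import optax

# Illustrative quadratic loss over a parameter tree.
def loss(params):
    return jnp.sum(params['w'] ** 2)

params = {'w': jnp.array([1.0, 2.0, 3.0])}
tx = optax.sgd(learning_rate=0.1)
opt_state = tx.init(params)

grads = jax.grad(loss)(params)                 # a tree matching params
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)  # adds updates to params, leaf-wise

# Equivalent manual application via tree_map:
# params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```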

Incremental update#

optax.incremental_update(new_tensors, old_tensors, step_size)#

Incrementally update parameters via Polyak averaging.

Polyak averaging tracks an (exponential moving) average of the past parameters of a model, for use at test/evaluation time.

References

[Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046)

Parameters:
  • new_tensors (optax.Params) – the latest value of the tensors.

  • old_tensors (optax.Params) – a moving average of the values of the tensors.

  • step_size (chex.Numeric) – the step size used to update the Polyak average on each step.

Return type:

optax.Params

Returns:

an updated moving average `step_size * new + (1 - step_size) * old` of the params.
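
As a sketch (the parameter values are illustrative), folding the latest parameters into a running average after each optimisation step:

```python
import jax.numpy as jnp
import optax

new_params = {'w': jnp.array([1.0, 2.0])}   # latest parameter values
avg_params = {'w': jnp.array([0.0, 0.0])}   # running Polyak average

# Computes step_size * new + (1 - step_size) * old, leaf-wise.
avg_params = optax.incremental_update(new_params, avg_params, step_size=0.01)
# avg_params['w'] is now [0.01, 0.02]
```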

Periodic update#

optax.periodic_update(new_tensors, old_tensors, steps, update_period)#

Periodically update all parameters with new values.

A slow copy of a model’s parameters, updated every K actual updates, can be used to implement forms of self-supervision (in supervised learning), or to stabilise temporal difference learning updates (in reinforcement learning).

References

[Grill et al., 2020](https://arxiv.org/abs/2006.07733)

[Mnih et al., 2015](https://arxiv.org/abs/1312.5602)

Parameters:
  • new_tensors (optax.Params) – the latest value of the tensors.

  • old_tensors (optax.Params) – a slow copy of the model’s parameters.

  • steps (chex.Array) – number of update steps on the “online” network.

  • update_period (int) – the period, in steps, at which the “target” network parameters are updated.

Return type:

optax.Params

Returns:

a slow copy of the model’s parameters, updated every update_period steps.
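
A sketch of target-network style usage (the training-loop body is elided and the names are illustrative; `steps` is the counter of updates applied to the online network):

```python
import jax.numpy as jnp
import optax

online_params = {'w': jnp.zeros(3)}
target_params = online_params  # the slow copy starts equal to the online network

for step in range(1_000):
    # ... update online_params with an optimiser here ...
    # Copies online_params into target_params whenever step is a
    # multiple of update_period, and leaves target_params unchanged otherwise.
    target_params = optax.periodic_update(
        online_params, target_params, jnp.asarray(step), update_period=100)
```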