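Apply Updates

- apply_updates(params, updates): Applies an update to the corresponding parameters.
- incremental_update(new_tensors, old_tensors, step_size): Incrementally update parameters via Polyak averaging.
- periodic_update(new_tensors, old_tensors, steps, update_period): Periodically update all parameters with new values.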
Apply updates
- optax.apply_updates(params: optax.Params, updates: optax.Updates) → optax.Params
Applies an update to the corresponding parameters.
This is a utility function that applies an update to a set of parameters and returns the updated parameters to the caller. For example, the update may be a gradient transformed by a sequence of `GradientTransformations`. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually using `jax.tree.map` (e.g. if you want to manipulate updates in custom ways before applying them).
- Parameters:
params – a tree of parameters.
updates – a tree of updates; the tree structure and the shape of the leaf nodes must match those of params.
- Returns:
Updated parameters, with same structure, shape and type as params.
Incremental update
- optax.incremental_update(new_tensors: optax.Params, old_tensors: optax.Params, step_size: jax.typing.ArrayLike) → optax.Params
Incrementally update parameters via Polyak averaging.
Polyak averaging tracks an (exponential moving) average of the past parameters of a model, for use at test/evaluation time.
- Parameters:
new_tensors – the latest value of the tensors.
old_tensors – a moving average of the values of the tensors.
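step_size – the weight given to new_tensors when updating the Polyak average on each step; 1.0 replaces old_tensors with new_tensors, while 0.0 leaves old_tensors unchanged.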
- Returns:
an updated moving average of the params, step_size * new_tensors + (1 - step_size) * old_tensors.
References
[Polyak et al., 1991](https://epubs.siam.org/doi/10.1137/0330046)
Periodic update
- optax.periodic_update(new_tensors: optax.Params, old_tensors: optax.Params, steps: jax.typing.ArrayLike, update_period: jax.typing.ArrayLike) → optax.Params
Periodically update all parameters with new values.
A slow copy of a model’s parameters, updated once every update_period updates of the online parameters, can be used to implement forms of self-supervision (in supervised learning), or to stabilize temporal difference learning updates (in reinforcement learning).
- Parameters:
new_tensors – the latest value of the tensors.
old_tensors – a slow copy of the model’s parameters.
steps – the number of update steps taken so far on the “online” network.
update_period – the number of steps between updates of the “target” network.
- Returns:
a slow copy of the model’s parameters, updated every update_period steps.
References
[Grill et al., 2020](https://arxiv.org/abs/2006.07733)
[Mnih et al., 2015](https://arxiv.org/abs/1312.5602)