optax.contrib.sam
- optax.contrib.sam(optimizer: base.GradientTransformation, adv_optimizer: base.GradientTransformation, sync_period: int = 2, reset_state: bool = True, opaque_mode: bool = False, batch_axis_name: str | None = None) -> base.GradientTransformationExtraArgs
Implementation of SAM (Sharpness-Aware Minimization).
Performs steps with the inner adversarial optimizer and periodically updates an outer set of true parameters. By default, resets the state of the adversarial optimizer after synchronization. For example:
>>> import optax
>>> rho = 0.1
>>> opt = optax.sgd(learning_rate=0.01)
>>> adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(rho))
>>> sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)
This implements the simple drop-in version of SAM from the paper, which uses normalized SGD as the inner adversarial optimizer for a single step.
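To make the arithmetic of that two-call cycle concrete, here is a minimal pure-Python sketch on a toy quadratic loss. It does not use optax: the loss, `grad`, and `sam_cycle` are hypothetical names introduced for illustration, and the update follows the normalized-gradient ascent step described in the paper rather than the library's internals.

```python
import math

def grad(w):
    # Gradient of the toy loss 0.5 * (w[0]**2 + 10 * w[1]**2).
    return [w[0], 10.0 * w[1]]

def sam_cycle(w, lr=0.01, rho=0.1):
    """One adversarial step followed by one outer step (sync_period=2)."""
    g = grad(w)
    # Guard against a zero gradient so the normalization never divides by 0.
    norm = max(math.sqrt(sum(gi * gi for gi in g)), 1e-12)
    # Adversarial step: move up the loss by rho along the normalized gradient.
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # Outer step: apply the gradient taken at w_adv to the original parameters.
    g_adv = grad(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

w_next = sam_cycle([1.0, 0.0])
```

Starting from `[1.0, 0.0]`, the adversarial point is `[1.1, 0.0]`, so the outer SGD step uses the gradient measured there and subtracts `0.01 * 1.1` from the original first coordinate.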
- Parameters:
optimizer – the outer optimizer.
adv_optimizer – the inner adversarial optimizer.
sync_period – how often to run the outer optimizer. Defaults to 2, i.e. every other step.
reset_state – whether to reset the state of the inner optimizer after every sync period. Defaults to True.
opaque_mode – bool. If True, the outer optimizer and the adversarial optimizer are run in an internal loop at each call to update, so that adversarial updates are opaque to the rest of the system. If False, one optimizer is (effectively) evaluated per call to update, meaning that adversarial updates are visible to the rest of the system. Setting opaque_mode to True is necessary if the training system using SAM has side effects from each call to update besides the changes to the model's parameters. The most common example would be if the model uses BatchNorm statistics: those statistics would be updated on both adversarial and non-adversarial update steps, causing them to get out of sync with the model's parameters (which are effectively only updated on non-adversarial steps). See the Note section for more details on opaque_mode=True.
batch_axis_name – str or None. Only used if opaque_mode=True. When running in a pmapped setting, it is necessary to take a jax.lax.pmean of the adversarial updates internally before passing them to the outer optimizer. You only need to specify this if you have to use jax.lax.pmean in your training loop.
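The phrase "adversarial updates are visible to the rest of the system" can be illustrated with a toy scalar sketch of the opaque_mode=False schedule. This is an assumption-laden illustration, not the library's code: `init`, `update`, and the state keys are hypothetical, and the adversarial step is a plain sign-ascent of size `rho`.

```python
def init(param):
    # Track how many calls have happened since the last outer step, plus
    # the "true" parameter value saved at the start of each cycle.
    return {"steps_since_sync": 0, "cache": param}

def update(grad, state, param, lr=0.01, rho=0.1, sync_period=2):
    """Return (delta, new_state); the caller applies param + delta."""
    step = state["steps_since_sync"]
    cache = param if step == 0 else state["cache"]
    if step == sync_period - 1:
        # Outer step: return to the cached true parameter and descend from
        # it using the gradient measured at the adversarial point.
        delta = (cache - lr * grad) - param
    else:
        # Adversarial step: ascend by rho along the normalized gradient,
        # which for a scalar is just its sign.
        delta = rho * (1.0 if grad >= 0 else -1.0)
    new_state = {"steps_since_sync": (step + 1) % sync_period, "cache": cache}
    return delta, new_state

# Toy loss 0.5 * w**2, so the gradient at w is w itself.
w, state = 1.0, init(1.0)
visible = []
for _ in range(4):
    delta, state = update(w, state, w)
    w = w + delta
    visible.append(w)
```

In this sketch the caller sees the perturbed value on every odd call and the true value on every even call, which is why anything else that reacts to each call (such as running statistics) would also observe the adversarial parameters.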
- Returns:
The corresponding optax.GradientTransformationExtraArgs implementation of SAM.
References
Foret et al., Sharpness-Aware Minimization for Efficiently Improving Generalization, 2021
Note
When opaque_mode=True, the update function must be called with a gradient function that takes two arguments (the params and the current adversarial step) and returns the gradients of the loss. This looks like the following:

    opt = sam(outer_opt, adv_opt, opaque_mode=True)
    ...
    # In the training loop:
    grad_fn = jax.grad(
        lambda params, _: loss(params, batch, and_other_args))
    updates, state = opt.update(updates, state, params, grad_fn=grad_fn)
    params = optax.apply_updates(params, updates)
On every call to opt.update, grad_fn will be called sync_period - 1 times, once for each adversarial update. It is usually ok to use the same minibatch in each of those updates, as in the example above, but you can use the second argument to select different batches at each adversarial step:

    grad_fn = jax.grad(lambda params, i: loss(params, batches[i]))
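The calling pattern for grad_fn can be sketched with a toy scalar loop. This is a pure-Python illustration under stated assumptions, not the optax implementation: `opaque_update`, `batches`, and `calls` are hypothetical names, the adversarial step is a sign-ascent of size `rho`, and the outer optimizer is plain SGD.

```python
def opaque_update(grad, param, grad_fn, lr=0.01, rho=0.1, sync_period=2):
    """Run every adversarial step internally and return only the outer delta."""
    adv_param, adv_grad = param, grad
    for i in range(sync_period - 1):
        # Adversarial ascent by rho along the normalized (scalar: sign) gradient.
        adv_param = adv_param + rho * (1.0 if adv_grad >= 0 else -1.0)
        # Re-evaluate the gradient at the perturbed point; i is the step index.
        adv_grad = grad_fn(adv_param, i)
    # Outer SGD step, to be applied to the original, unperturbed param.
    return -lr * adv_grad

batches = [0.0, 0.5]  # hypothetical per-step targets
calls = []

def grad_fn(p, i):
    calls.append(i)  # record which adversarial step asked for a gradient
    return p - batches[i]  # gradient of 0.5 * (p - batches[i]) ** 2

delta = opaque_update(grad=1.0, param=1.0, grad_fn=grad_fn, sync_period=3)
```

With sync_period=3 the recorded indices are 0 and 1, matching the sync_period - 1 calls described above, and the caller only ever sees the final outer delta.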