optax.contrib.sam

optax.contrib.sam(optimizer: base.GradientTransformation, adv_optimizer: base.GradientTransformation, sync_period: int = 2, reset_state: bool = True, opaque_mode: bool = False, batch_axis_name: str | None = None) -> base.GradientTransformationExtraArgs

Implementation of SAM (Sharpness Aware Minimization).

Performs steps with the inner adversarial optimizer and periodically updates an outer set of true parameters. By default, resets the state of the adversarial optimizer after synchronization. For example:

>>> import optax
>>> rho = 0.1
>>> opt = optax.sgd(learning_rate=0.01)
>>> adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(rho))
>>> sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)

This implements the simple drop-in version of SAM from the paper, which uses normalized SGD as the inner adversarial optimizer for a single step.

Parameters:
  • optimizer – the outer optimizer.

  • adv_optimizer – the inner adversarial optimizer.

  • sync_period – how often to run the outer optimizer, defaults to 2, or every other step.

  • reset_state – whether to reset the state of the inner optimizer after every sync period, defaults to True.

  • opaque_mode – bool. If True, the outer optimizer and the adversarial optimizer are run in an internal loop at each call to update, so that adversarial updates are opaque to the rest of the system. If False, one optimizer is (effectively) evaluated per call to update, meaning that adversarial updates are visible to the rest of the system. Setting opaque_mode to True is necessary if the training system using SAM has side effects from each call to update besides the changes to the model's parameters. The most common example would be if the model uses BatchNorm statistics – those statistics would be updated on both adversarial and non-adversarial update steps, causing them to get out of sync with the model's parameters (which are effectively only updated on non-adversarial steps). See the NOTE section for more details on opaque_mode=True.

  • batch_axis_name – str or None. Only used if opaque_mode=True. When running in a pmapped setting, it is necessary to take a jax.lax.pmean of the adversarial updates internally before passing them to the outer optimizer. You only need to specify this if you have to use jax.lax.pmean in your training loop.

Returns:

The corresponding optax.GradientTransformationExtraArgs implementation of SAM.

References

Foret et al., Sharpness-Aware Minimization for Efficiently Improving Generalization, 2021

Note

When opaque_mode=True, the update function must be called with a gradient function that takes two arguments (the params and the current adversarial step) and returns the gradients of the loss. This looks like the following:

opt = optax.contrib.sam(outer_opt, adv_opt, opaque_mode=True)
...
# In the training loop:
grad_fn = jax.grad(
  lambda params, _: loss(params, batch, and_other_args))
# The `updates` passed in are the gradients at the current params.
updates, state = opt.update(updates, state, params, grad_fn=grad_fn)
params = optax.apply_updates(params, updates)

On every call to opt.update, grad_fn will be called sync_period - 1 times, once for each adversarial update. It is usually fine to use the same minibatch in each of those updates, as in the example above, but you can use the second argument to select a different batch at each adversarial step:

grad_fn = jax.grad(lambda params, i: loss(params, batches[i]))