optax.contrib.SAMState

class optax.contrib.SAMState(steps_since_sync: jax.Array, opt_state: base.OptState, adv_state: base.OptState, cache: base.Params | None)

State of the GradientTransformation returned by optax.contrib.sam.

steps_since_sync

Number of adversarial steps taken since the last outer update.

Type:

jax.Array
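The counter's role can be pictured with a small sketch. The helper advance_sync_counter below is hypothetical and not part of optax; it assumes that the outer update runs on the step where the counter equals sync_period - 1, after which the counter wraps back to 0, so each cycle consists of sync_period - 1 adversarial steps followed by one outer step.

```python
import jax.numpy as jnp

def advance_sync_counter(steps_since_sync, sync_period):
    """Hypothetical helper: returns (next_counter, is_outer_step).

    Assumes the outer update happens when the counter reaches
    sync_period - 1, after which the counter wraps to 0.
    """
    is_outer_step = steps_since_sync == sync_period - 1
    next_counter = (steps_since_sync + 1) % sync_period
    return next_counter, is_outer_step

# Scalar integer jax.Array, like steps_since_sync.
counter = jnp.zeros((), dtype=jnp.int32)
schedule = []
for _ in range(6):
    counter, is_outer = advance_sync_counter(counter, sync_period=3)
    schedule.append(bool(is_outer))

print(schedule)  # [False, False, True, False, False, True]
```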

opt_state

State of the outer optimizer.

Type:

base.OptState

adv_state

State of the inner adversarial optimizer.

Type:

base.OptState

cache

A place to store the last outer updates.

Type:

Optional[base.Params]