optax.contrib.SAMState
- class optax.contrib.SAMState(steps_since_sync: jax.Array, opt_state: base.OptState, adv_state: base.OptState, cache: base.Params | None)
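State of the GradientTransformation returned by sam.
- steps_since_sync
Number of adversarial steps taken since the last outer update.
- Type:
jax.Array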
- opt_state
State of the outer optimizer.
- Type:
base.OptState
- adv_state
State of the inner adversarial optimizer.
- Type:
base.OptState
- cache
A place to store the last outer updates.
- Type:
Optional[base.Params]