optax.ConditionallyMaskState

class optax.ConditionallyMaskState(step, inner_state)
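The two fields show how a conditionally masked optimizer keeps track of time. `step` counts update calls, and `inner_state` holds the state of the wrapped transformation. The following is a plain-Python sketch of that bookkeeping. It assumes the state is paired with a wrapper like `optax.conditionally_mask` that applies the inner transformation only when a predicate on `step` is true, and returns zero updates otherwise. It is not optax's implementation: optax works on JAX pytrees and uses `jax.lax.cond`. The names `Transform`, `momentum`, and `conditionally_mask_sketch` are hypothetical.

```python
from typing import Any, Callable, NamedTuple


class ConditionallyMaskState(NamedTuple):
    """Illustrative mirror of the documented fields (not optax's class)."""
    step: int          # number of update calls seen so far
    inner_state: Any   # state of the wrapped (inner) transformation


class Transform(NamedTuple):
    # Hypothetical stand-in for an optax-style (init, update) pair.
    init: Callable[[list], Any]
    update: Callable[[list, Any], tuple]


def momentum(lr: float, beta: float = 0.9) -> Transform:
    # Hypothetical inner transformation: SGD with momentum on a list of floats.
    def init(params):
        return [0.0 for _ in params]

    def update(grads, trace):
        new_trace = [beta * t + g for t, g in zip(trace, grads)]
        return [-lr * t for t in new_trace], new_trace

    return Transform(init, update)


def conditionally_mask_sketch(
    inner: Transform, should_transform_fn: Callable[[int], bool]
) -> Transform:
    def init(params):
        return ConditionallyMaskState(step=0, inner_state=inner.init(params))

    def update(grads, state):
        if should_transform_fn(state.step):
            # Predicate true: run the inner transformation and keep its new state.
            updates, inner_state = inner.update(grads, state.inner_state)
        else:
            # Masked step: emit zero updates and leave the inner state untouched.
            updates, inner_state = [0.0 for _ in grads], state.inner_state
        # The step counter advances on every call, masked or not.
        return updates, ConditionallyMaskState(state.step + 1, inner_state)

    return Transform(init, update)


# Usage: mask even steps, apply momentum SGD on odd steps.
opt = conditionally_mask_sketch(momentum(lr=0.1), lambda step: step % 2 == 1)
state = opt.init([1.0])
u0, state = opt.update([1.0], state)  # step 0: masked, updates are zero
u1, state = opt.update([1.0], state)  # step 1: inner momentum update applied
print(u0, u1, state)
```

In this sketch, the inner momentum trace only accumulates on unmasked steps, because `inner_state` is passed through unchanged whenever the predicate is false.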