optax.MaskedState

class optax.MaskedState(inner_state: Any)

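Maintains the state of the inner transformation wrapped by a masked transformation such as optax.masked. The wrapped state is stored in the single inner_state field.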