optax.ConditionallyTransformState


class optax.ConditionallyTransformState(inner_state: Any, step: jax.typing.ArrayLike)

Maintains the state of the wrapped (inner) transformation in inner_state and adds a step counter, step.
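The following minimal sketch shows how a state with these two fields can drive a conditional wrapper. It assumes a wrapper that applies the inner update only when a predicate on step returns True, and otherwise passes the updates and inner_state through unchanged. As far as we know, optax.conditionally_transform follows this pattern, but check the installed version. SketchConditionallyTransformState, sketch_conditional_update, and inner_update are hypothetical names for illustration. This is not optax's implementation, and overflow handling for the counter is omitted.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SketchConditionallyTransformState(NamedTuple):
  """Hypothetical stand-in mirroring the documented fields."""
  inner_state: Any              # state of the wrapped (inner) transformation
  step: jax.typing.ArrayLike    # number of update calls seen so far


def sketch_conditional_update(updates, state, inner_update, should_transform_fn):
  """Hypothetical wrapper: run inner_update only when should_transform_fn(step) is True."""

  def do_update(_):
    return inner_update(updates, state.inner_state)

  def skip_update(_):
    # Pass the updates and the inner state through unchanged.
    return updates, state.inner_state

  new_updates, new_inner = jax.lax.cond(
      should_transform_fn(state.step), do_update, skip_update, None)
  # The counter advances on every call, whether or not the inner update ran.
  return new_updates, SketchConditionallyTransformState(new_inner, state.step + 1)


def inner_update(updates, inner_state):
  # Hypothetical inner transform: halve updates and count how often it ran.
  return updates * 0.5, inner_state + 1


state = SketchConditionallyTransformState(
    inner_state=jnp.zeros([], jnp.int32), step=jnp.zeros([], jnp.int32))
grads = jnp.ones(3)
outputs = []
for _ in range(4):
  # Apply the inner transform on even steps only.
  new_grads, state = sketch_conditional_update(
      grads, state, inner_update, lambda s: s % 2 == 0)
  outputs.append(new_grads)

print(int(state.step), int(state.inner_state))
```

After four calls, step is 4 while the inner transform ran only twice (steps 0 and 2). The counter lives in the wrapper's state rather than in inner_state for this reason: the inner transformation's own state is not advanced on skipped steps.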