optax.ConditionallyTransformState

class optax.ConditionallyTransformState(inner_state: Any, step: jax.typing.ArrayLike)

Maintains the inner transform's state and adds a step counter.
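The pure-Python sketch below shows how the two fields can be used together: `inner_state` holds the wrapped transform's state and `step` counts update calls. In optax, this state appears to be what `optax.conditionally_transform` produces. That wrapper applies an inner transform only when a predicate on the step returns true. Everything named `*_sketch` here is hypothetical and is not optax code. The example also assumes two behaviours: skipped steps pass updates through unchanged, and they leave the inner state frozen. It uses no JAX, so a plain `if` stands in for whatever traced conditional the real library uses.

```python
from typing import Any, Callable, NamedTuple


class ConditionallyTransformStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the two documented fields."""
    inner_state: Any  # state of the wrapped (inner) transformation
    step: int         # number of update calls seen so far


def scale_and_count_sketch():
    """Hypothetical inner transform: halves updates and counts how often it ran."""
    def init(params):
        return 0  # inner state: number of times the inner update actually executed

    def update(updates, inner_state):
        return [0.5 * u for u in updates], inner_state + 1

    return init, update


def conditionally_transform_sketch(inner, should_transform: Callable[[int], bool]):
    """Hypothetical wrapper: run `inner` only when should_transform(step) is true."""
    inner_init, inner_update = inner

    def init(params):
        return ConditionallyTransformStateSketch(inner_state=inner_init(params), step=0)

    def update(updates, state):
        if should_transform(state.step):
            updates, new_inner = inner_update(updates, state.inner_state)
        else:
            # Assumed behaviour: pass updates through and keep the inner state as-is.
            new_inner = state.inner_state
        # The step counter advances on every call, whether or not the inner transform ran.
        return updates, ConditionallyTransformStateSketch(new_inner, state.step + 1)

    return init, update


# Usage: apply the inner transform only on even steps (0, 2, ...).
init, update = conditionally_transform_sketch(
    scale_and_count_sketch(), lambda step: step % 2 == 0
)
state = init([1.0, 2.0])
history = []
for _ in range(4):
    out, state = update([1.0, 2.0], state)
    history.append(out)
```

After four calls, `state.step` is 4, but `state.inner_state` is only 2. The counter in the outer state is therefore separate from any bookkeeping the inner transform keeps for itself.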