optax.FactoredState

class optax.FactoredState(count: jax.typing.ArrayLike, v_row: optax.ArrayTree, v_col: optax.ArrayTree, v: optax.ArrayTree)

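Overall state of the factored second-moment gradient transformation `optax.scale_by_factored_rms` (the scaling step used by `optax.adafactor`). `count` is the number of update steps taken so far. `v_row` and `v_col` hold the row and column second-moment accumulators for parameters that are factored. `v` holds the full, unfactored accumulator for parameters that are not factored.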