optax.contrib.GaLoreState

class optax.contrib.GaLoreState(count: jax.typing.ArrayLike, base_optimizer_state: optax.OptState, projector: optax.Updates)

State for the GaLore (Gradient Low-Rank Projection) optimizer.

count

Number of update steps taken.

Type:

jax.typing.ArrayLike
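A step counter like `count` is typically what decides when the projection matrices get recomputed; the GaLore paper refreshes its low-rank subspace periodically rather than at every step. The sketch below is illustrative only: `increment_count`, `should_refresh_projector`, and the `update_freq` schedule are hypothetical names and an assumed policy, not part of the optax API.

```python
import jax.numpy as jnp

def increment_count(count):
    # Hypothetical helper: bump the step counter, keeping it an int32 scalar array.
    return jnp.asarray(count, dtype=jnp.int32) + 1

def should_refresh_projector(count, update_freq):
    # Assumed schedule: recompute projectors at step 0 and every `update_freq` steps.
    return count % update_freq == 0

count = jnp.zeros([], jnp.int32)
refresh_steps = []
for _ in range(7):
    if should_refresh_projector(count, update_freq=3):
        refresh_steps.append(int(count))
    count = increment_count(count)
print(refresh_steps, int(count))  # [0, 3, 6] 7
```

Storing the counter as a JAX array rather than a Python int keeps the state a valid pytree that can pass through `jax.jit`.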

base_optimizer_state

State of the base optimizer, which operates on low-rank (projected) gradients for 2D parameters and on full gradients for all other parameters.

Type:

optax.OptState
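To make the split between 2D and non-2D parameters concrete, here is a minimal sketch of the gradients a base optimizer would receive. It assumes parameters are held in a flat dict keyed by name and that `projectors` has an entry only for 2D parameters; `project_gradients` is a hypothetical helper, not an optax function.

```python
import jax.numpy as jnp

def project_gradients(grads, projectors):
    """Hypothetical helper: low-rank gradients for 2D params, full gradients otherwise.

    `grads` and `projectors` are flat dicts keyed by parameter name; `projectors`
    only has entries for 2D parameters (an assumption of this sketch).
    """
    out = {}
    for name, g in grads.items():
        if g.ndim == 2:
            # (m, n) -> (rank, n): the base optimizer sees the compressed gradient.
            out[name] = projectors[name].T @ g
        else:
            # Biases, scalars, conv kernels, etc.: passed through unchanged.
            out[name] = g
    return out

grads = {"w": jnp.ones((6, 4)), "b": jnp.ones((4,))}
projectors = {"w": jnp.eye(6)[:, :2]}  # (6, 2) projector with orthonormal columns
low = project_gradients(grads, projectors)
print(low["w"].shape, low["b"].shape)  # (2, 4) (4,)
```

A base optimizer initialized and updated on these projected gradients would keep its per-parameter buffers (for example, Adam moments) at the compressed `(rank, n)` shape for 2D weights, which is where GaLore's memory savings come from.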

projector

Projection matrices for each 2D parameter.

Type:

optax.Updates
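GaLore derives its projection matrices from a singular value decomposition of the gradient. The sketch below shows one way to build such a projector and to move between the full and low-rank spaces. `make_projector`, `project`, and `project_back` are hypothetical names. The sketch also always projects with left singular vectors, whereas a real implementation may project on the other side depending on the matrix shape.

```python
import jax
import jax.numpy as jnp

def make_projector(grad, rank):
    """Hypothetical helper: orthonormal basis for the top-`rank` left singular subspace."""
    # Thin SVD: grad (m, n) -> u (m, k), s (k,), vt (k, n), with k = min(m, n).
    u, _, _ = jnp.linalg.svd(grad, full_matrices=False)
    return u[:, :rank]  # (m, rank), orthonormal columns

def project(grad, p):
    # Map the full gradient (m, n) into the low-rank space (rank, n).
    return p.T @ grad

def project_back(low_rank_update, p):
    # Map a low-rank update (rank, n) back to the parameter shape (m, n).
    return p @ low_rank_update

key = jax.random.PRNGKey(0)
g = jax.random.normal(key, (8, 5))
p = make_projector(g, rank=2)
low = project(g, p)
full = project_back(low, p)
print(p.shape, low.shape, full.shape)  # (8, 2) (2, 5) (8, 5)
```

Because the columns of `p` are orthonormal, `project_back(project(g, p), p)` is the best rank-2 approximation of `g` in the Frobenius norm, which is why the base optimizer can work in the smaller space with limited loss of gradient information.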