optax.contrib.GaLoreState
- class optax.contrib.GaLoreState(count: jax.typing.ArrayLike, base_optimizer_state: optax.OptState, projector: optax.Updates)
State for the GaLore (Gradient Low-Rank Projection) optimizer.
- count
Number of update steps taken.
- Type:
jax.typing.ArrayLike
- base_optimizer_state
State for the base optimizer, operating on low-rank gradients for 2D params and full gradients for non-2D params.
- Type:
optax.OptState
- projector
Projection matrices for each 2D parameter.
- Type:
optax.Updates