Optax is a gradient processing and optimization library for JAX. It is designed to facilitate research by providing building blocks that can be recombined in custom ways in order to optimise parametric models such as, but not limited to, deep neural networks.
Our goals are to:

- Provide readable, well-tested, efficient implementations of core components.
- Improve researcher productivity by making it possible to combine low-level ingredients into custom optimisers (or other gradient processing components).
- Accelerate adoption of new ideas by making it easy for anyone to contribute.
We favour focusing on small, composable building blocks that can be effectively combined into custom solutions. Others may build more complicated abstractions on top of these basic components. Whenever reasonable, implementations prioritise readability and structuring code to match the standard equations, over code reuse.
The latest release of Optax can be installed from PyPI using:

```
pip install optax
```
You may also install the most recent version of Optax directly from GitHub using the following command:

```
pip install git+https://github.com/deepmind/optax.git
```
Note that Optax is built on top of JAX; see the JAX documentation for installation instructions.
- Common Optimizers
- Optax Transformations
- Apply Updates
- Combining Optimizers
- Optimizer Wrappers
- Common Losses
- Linear Algebra Operators
- Utilities for numerical stability
- Optimizer Schedules
- Second Order Optimization Utilities
- Control Variates
- Stochastic Gradient Estimators
- Privacy-Sensitive Optax Methods
- General Utilities
- 🔧 Contrib
- 🚧 Experimental
The development of Optax is led by Ross Hemsley, Matteo Hessel, Markus Kunesch and Iurii Kemaev. The team relies on outstanding contributions from Research Engineers and Research Scientists from throughout DeepMind and Alphabet. We are also very grateful to Optax’s open source community for contributing ideas, bug fixes, issues, design docs, and amazing new features.
The work on Optax is part of a wider effort to contribute to making the JAX Ecosystem the best possible environment for ML/AI research.