Optax is a gradient processing and optimization library for JAX. It is designed to facilitate research by providing building blocks that can be recombined in custom ways in order to optimise parametric models such as, but not limited to, deep neural networks.
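
For example, here is a minimal sketch of the standard Optax init/update loop. The quadratic loss and the hyperparameters below are purely illustrative:

```python
import jax
import jax.numpy as jnp
import optax

# Toy quadratic loss; any differentiable function of the parameters works.
def loss_fn(params):
  return jnp.sum((params - 1.0) ** 2)

params = jnp.zeros(3)
optimizer = optax.adam(learning_rate=0.1)  # any Optax optimiser works here
opt_state = optimizer.init(params)

for _ in range(100):
  grads = jax.grad(loss_fn)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
```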

Our goals are to:

  • Provide readable, well-tested, efficient implementations of core components.

  • Improve researcher productivity by making it possible to combine low-level ingredients into custom optimisers (or other gradient-processing components).

  • Accelerate adoption of new ideas by making it easy for anyone to contribute.

We favour small, composable building blocks that can be effectively combined into custom solutions. Others may build more complicated abstractions on top of these basic components. Whenever reasonable, implementations prioritise readability and structure code to match the standard equations, over code reuse.
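
As an illustration of this composability, a familiar optimiser can be assembled from individual gradient transformations with `optax.chain`. The sketch below builds an Adam-style update with gradient clipping; the clipping threshold and learning rate are illustrative:

```python
import optax

learning_rate = 1e-3

# An Adam-style optimiser assembled from individual gradient transformations.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip the global gradient norm to 1.0
    optax.scale_by_adam(),           # rescale gradients using Adam-style moment estimates
    optax.scale(-learning_rate),     # flip the sign for descent and apply the learning rate
)
```

The resulting object is an ordinary `GradientTransformation`, so it can be used exactly as in the init/update loop above.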

The Team

The development of Optax is led by Ross Hemsley, Matteo Hessel, Markus Kunesch and Iurii Kemaev. The team relies on outstanding contributions from Research Engineers and Research Scientists throughout DeepMind and Alphabet. We are also very grateful to Optax’s open source community for contributing ideas, bug fixes, issues, design docs, and amazing new features.

The work on Optax is part of a wider effort to contribute to making the [JAX Ecosystem](https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt) the best possible environment for ML/AI research.


If you run into problems, please let us know by filing an issue on our issue tracker.


Optax is licensed under the Apache 2.0 License.
