Optax is a gradient processing and optimization library for JAX. It is designed to facilitate research by providing building blocks that can be recombined in custom ways in order to optimise parametric models such as, but not limited to, deep neural networks.

Our goals are to:

  • Provide readable, well-tested, efficient implementations of core components;

  • Improve researcher productivity by making it possible to combine low-level ingredients into custom optimisers (or other gradient processing components);

  • Accelerate adoption of new ideas by making it easy for anyone to contribute.

We favour focusing on small composable building blocks that can be effectively combined into custom solutions. Others may build more complicated abstractions on top of these basic components. Whenever reasonable, implementations prioritise readability, and structure code to match the standard equations, over code reuse.


The latest release of Optax can be installed from PyPI using:

pip install optax

You may also install the latest development version directly from GitHub using the following command:

pip install git+https://github.com/google-deepmind/optax.git

Note that Optax is built on top of JAX. See the JAX documentation for instructions on installing JAX.


If you encounter issues with this software, please let us know by filing an issue on our issue tracker. We are also happy to receive bug fixes and other contributions. For more information on how to contribute, please see the development guide.


Optax is licensed under the Apache 2.0 License.
