How DeepMind Is Using JAX To Accelerate AI Research


JAX is a Python library developed by Google researchers for high-performance numerical computing. Its API is based on NumPy, the widely used library of array functions for scientific computing. Because Python and NumPy are so broadly adopted by developers, JAX feels familiar and is simple, flexible, and easy to use. JAX and its growing ecosystem of open-source libraries have supported and accelerated numerous machine learning projects.
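To make the NumPy connection concrete, here is a minimal sketch, assuming a recent `jax` install, of a NumPy-style function written with `jax.numpy`, differentiated with `jax.grad`, and compiled with `jax.jit`. The function name `mse_loss`, the toy data, and the learning rate are illustrative choices, not taken from DeepMind's code.

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function, written with jax.numpy instead of numpy.
# (Hypothetical example: mean squared error of a linear model y ≈ x @ w.)
def mse_loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad builds a new function returning d(loss)/d(w),
# differentiating with respect to the first argument by default.
grad_fn = jax.grad(mse_loss)

# jax.jit compiles that gradient function with XLA (CPU, GPU, or TPU).
fast_grad_fn = jax.jit(grad_fn)

# Toy data generated from known weights so the result can be checked.
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
true_w = jnp.array([0.5, -1.0])
y = x @ true_w

# Plain gradient descent using the compiled gradient.
w = jnp.zeros(2)
for _ in range(200):
    w = w - 0.1 * fast_grad_fn(w, x, y)

print(w)  # should be close to true_w
```

The model code itself is ordinary array math; differentiation and compilation are added by wrapping it in function transformations rather than by rewriting it.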

JAX at DeepMind

When supporting AI research, it is essential that experiments can scale to real-world applications, yet advancing research also requires rapid prototyping and quick iteration. DeepMind's researchers take several approaches to balance these needs so that the core JAX libraries can keep pace with new research directions.

One approach is to distill the most essential building blocks developed in each research project into well-tested, efficient components. Researchers then benefit from code reuse, bug fixes, and performance improvements to the algorithmic ingredients housed in the core libraries.

Other considerations include keeping the libraries consistent with the design of existing TensorFlow libraries. Each library has a clearly defined scope and is designed to be usable across different projects.

DeepMind’s JAX Ecosystem

DeepMind has open-sourced an ecosystem of JAX libraries to support machine learning research, including:

  • Haiku for neural network modules: Haiku is a neural network library for JAX that simplifies managing model parameters and other model state. It lets users work with familiar object-oriented programming models while keeping the power and simplicity of JAX’s pure functional paradigm. It has been adopted in external projects such as Coax, DeepChem, and NumPyro.
  • Optax for gradient processing and optimization: Optax is a gradient transformation library whose composition operators let users implement many standard optimizers in a single line of code. Projects using it include Elegy, Flax, and Stax.
  • RLax for RL algorithms: RLax provides the essential building blocks for constructing reinforcement learning (RL) agents. Its components cover a broad spectrum of algorithms and ideas, including TD-learning, actor-critic methods, policy gradients, MAP, proximal policy optimization, non-linear value transformations, general value functions, and a number of exploration techniques. Acme is an example of a fully featured agent framework built on RLax components.
  • Chex for reliable code and testing: Chex is a library of utilities for writing reliable JAX code. Its testing utilities are used both by library authors, to verify that the common building blocks are correct and robust, and by end users, to check their own code. Chex is used throughout DeepMind and in external projects such as Coax and MineRL.
  • Jraph for graph neural networks: Jraph supports working with graph neural networks (GNNs) in JAX. It is a lightweight library that provides a standardized data structure for graphs, a set of utilities for working with graphs, and a ‘zoo’ of easily forkable GNN models. Other key features include batching of GraphsTuples that takes advantage of hardware accelerators, JIT-compilation support for variable-shaped graphs through padding and masking, and losses defined over input partitions. Like the other libraries in the ecosystem, Jraph places no constraints on the user’s choice of neural network library.