JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

  • JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
  • JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
  • Run anywhere: the same code executes on multiple backends, including CPU, GPU, and TPU.
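Here is a minimal sketch of the three points above in plain JAX. `jax.numpy`, `jax.grad`, `jax.jit`, `jax.vmap` and `jax.default_backend` are real JAX APIs. The functions `loss` and `predict` and the toy data are hypothetical, chosen so that the weights fit the data exactly. Parallelization (for example, sharding across devices) is left out to keep it short.

```python
import jax
import jax.numpy as jnp

# NumPy-style API: jnp mirrors much of numpy's interface.
def loss(w, x, y):
    # Mean squared error of a linear model (hypothetical example).
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Composable transformations: differentiate, then compile the gradient with XLA.
grad_loss = jax.jit(jax.grad(loss))

# Batching: write the function for one example, then vmap it over axis 0 of x.
def predict(w, x_row):
    return jnp.dot(x_row, w)

batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])

# w fits the data exactly, so the gradient of the loss is zero.
g = grad_loss(w, x, y)
preds = batched_predict(w, x)

# The same code runs on whichever backend JAX finds: "cpu", "gpu", or "tpu".
print(jax.default_backend())
print(g, preds)
```

Nothing in `loss` or `predict` mentions a device or a batch dimension. The transformations add those concerns from the outside, which is what "composable" means in practice.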

XLA

JAX compiles to XLA, a machine learning compiler that predates MLIR (and is gradually being ported over to MLIR, I believe). JAX inherits the limitations of both Python (e.g., the language has no way of representing structs, allocating memory directly, or writing fast loops) and XLA (which is largely limited to machine-learning-specific concepts and is primarily targeted at TPUs), but it has the huge upside that it doesn't require a new language or a new compiler.
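One way to see the hand-off to XLA is to trace a small function and look at what JAX produces before XLA compiles it. This is a sketch that assumes a recent JAX release. `jax.make_jaxpr` prints JAX's own traced form, and the ahead-of-time API `jax.jit(f).lower(...)` returns a lowered module whose `as_text()` is the IR given to XLA. In recent versions that IR is StableHLO, an MLIR dialect, but the exact text varies by version. The function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise computation that XLA can fuse into one kernel.
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.ones((4,), dtype=jnp.float32)

# JAX's traced intermediate form (a jaxpr), before lowering.
print(jax.make_jaxpr(f)(x))

# Ahead-of-time lowering: this text is the module handed to XLA.
lowered = jax.jit(f).lower(x)
ir_text = lowered.as_text()
print(ir_text)

# XLA compiles the module for the current backend; the result is callable.
compiled = lowered.compile()
result = compiled(x)
print(result)
```

Only the jaxpr-to-IR step belongs to JAX. Fusion, memory layout and code generation for CPU, GPU or TPU all happen inside XLA, which is why XLA's limitations carry over to JAX.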