This guide explains how to use the JAX backend in Meridian.
Introduction to the JAX backend
By default, Meridian uses TensorFlow for its core numerical operations and probabilistic Markov Chain Monte Carlo (MCMC) sampling, providing a robust and thoroughly tested foundation for all modeling tasks.
For projects that need faster runs or lower memory usage, Meridian also provides a JAX backend. JAX encourages a functional programming style and uses XLA (Accelerated Linear Algebra) compilation to optimize numerical code.
Tutorial: For a hands-on walkthrough, see the Getting started with JAX notebook.
How to enable JAX
Because Meridian loads its core numerical libraries when it is first imported, you must tell Meridian to use JAX before importing any Meridian modules.
To enable JAX, set the MERIDIAN_BACKEND environment variable to 'jax' in your script before any import meridian statement runs:
import os
# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load
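In notebooks it is easy to import Meridian before setting the variable, in which case the backend choice may not take effect. The sketch below wraps the setup in a hypothetical helper, `enable_meridian_jax`, which is not part of Meridian's API. It checks `sys.modules` and raises an error if any Meridian module is already loaded, then sets the same environment variables described in this guide.

```python
import os
import sys


def enable_meridian_jax(enable_x64: bool = False) -> None:
    """Select Meridian's JAX backend (hypothetical helper, not Meridian API).

    Raises RuntimeError if any Meridian module is already imported, because
    the backend must be chosen before Meridian is first imported.
    """
    loaded = sorted(
        name for name in list(sys.modules)
        if name == "meridian" or name.startswith("meridian.")
    )
    if loaded:
        raise RuntimeError(
            f"Meridian is already imported ({', '.join(loaded)}); restart the "
            "interpreter and configure the backend before importing it."
        )
    os.environ["MERIDIAN_BACKEND"] = "jax"
    if enable_x64:
        # Optional 64-bit precision, as described in "Enable 64-bit precision".
        os.environ["MERIDIAN_ENABLE_JAX_X64"] = "True"
```

Call `enable_meridian_jax()` at the top of your script or in the first notebook cell, then import Meridian modules as shown above.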
Enable 64-bit precision
For models where convergence is difficult to achieve, using 64-bit precision with the JAX backend can improve numerical stability. However, 64-bit precision increases memory usage and slows computation, so 32-bit precision remains the default for most use cases. To enable 64-bit precision, set the MERIDIAN_ENABLE_JAX_X64 environment variable to 'True' before importing Meridian.
import os
# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'
# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
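The memory cost of 64-bit precision is easy to estimate for any single array: each float64 value takes 8 bytes versus 4 bytes for float32, so the same array needs twice the storage. The NumPy sketch below (not Meridian code) applies this to a hypothetical array of posterior draws. The shape is an illustrative assumption, and real peak memory also includes intermediate values computed during sampling.

```python
import math

import numpy as np


def array_nbytes(shape, dtype) -> int:
    """Bytes needed for one dense array of `shape` with element type `dtype`."""
    return math.prod(shape) * np.dtype(dtype).itemsize


# Hypothetical posterior-draws array: (chains, kept draws, geos, channels).
draws_shape = (4, 1000, 50, 10)
for dtype in (np.float32, np.float64):
    print(f"{np.dtype(dtype).name}: {array_nbytes(draws_shape, dtype) / 1e6:.0f} MB")
# float32: 8 MB
# float64: 16 MB
```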
If MERIDIAN_BACKEND is set to an unrecognized value, Meridian issues a RuntimeWarning and falls back to the standard TensorFlow backend. If MERIDIAN_ENABLE_JAX_X64 is set to any value other than 'True' or '1', 64-bit precision is not enabled and Meridian uses 32-bit precision.
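The fallback rules above can be restated as a small function, which may help when debugging which configuration a process will pick up. This is a hypothetical re-implementation of the documented behavior, not Meridian's source. It assumes the accepted backend names are exactly 'tensorflow' and 'jax' (case-sensitive), and that leaving MERIDIAN_BACKEND unset selects TensorFlow without a warning.

```python
import warnings
from typing import Mapping

# Assumption: the accepted backend names, matched exactly (case-sensitive).
VALID_BACKENDS = ("tensorflow", "jax")


def resolve_backend(env: Mapping[str, str]) -> str:
    """Return the backend that the documented rules would select."""
    value = env.get("MERIDIAN_BACKEND")
    if value is None:
        return "tensorflow"  # Assumption: unset means the TensorFlow default.
    if value not in VALID_BACKENDS:
        warnings.warn(
            f"Unrecognized MERIDIAN_BACKEND={value!r}; falling back to TensorFlow.",
            RuntimeWarning,
        )
        return "tensorflow"
    return value


def x64_enabled(env: Mapping[str, str]) -> bool:
    """64-bit precision is on only for the exact strings 'True' or '1'."""
    return env.get("MERIDIAN_ENABLE_JAX_X64") in ("True", "1")


print(resolve_backend({"MERIDIAN_BACKEND": "jax"}))     # jax
print(x64_enabled({"MERIDIAN_ENABLE_JAX_X64": "1"}))    # True
print(x64_enabled({"MERIDIAN_ENABLE_JAX_X64": "yes"}))  # False
```

Passing `os.environ` instead of a dict applies the same rules to the current process.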
Numerical differences and reproducibility
Because TensorFlow and JAX compile their computational graphs differently, you may observe minor numerical differences in your posterior estimates after switching to JAX, even with the same data and random seeds.
While posterior distributions may not be identical across backends, the differences are generally small and not statistically significant for business metrics such as ROI and budget allocation, so switching to the JAX backend should not change the insights you draw from your model.
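If you want to check this for your own model, one simple approach is to compare the posterior draws of a single metric from each backend and express the gap between their means in units of the posterior's own spread. The helper below is a hypothetical sketch, not a Meridian function. How you extract the draws (for example, one channel's ROI) from each fitted model is not shown here and is left as an assumption.

```python
import numpy as np


def standardized_mean_difference(draws_a, draws_b) -> float:
    """|mean(a) - mean(b)| divided by the pooled posterior standard deviation.

    Draws are flattened, so chains and kept draws are pooled together.
    """
    a = np.asarray(draws_a, dtype=np.float64).ravel()
    b = np.asarray(draws_b, dtype=np.float64).ravel()
    pooled_sd = np.sqrt((a.var() + b.var()) / 2.0)
    if pooled_sd == 0:
        raise ValueError("Pooled standard deviation is zero; comparison is undefined.")
    return float(abs(a.mean() - b.mean()) / pooled_sd)
```

For example, `standardized_mean_difference(roi_draws_tf, roi_draws_jax)`, where both arguments are hypothetical arrays of draws you exported from each backend. Values far below 1 mean the backends' posterior means differ by much less than the uncertainty the posterior already expresses.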
Performance considerations
In internal testing on GPUs, JAX reduced average runtime for initial model runs by ~40% and memory usage by ~70% compared to TensorFlow. For subsequent model iterations, JAX delivered 2x faster runtimes and 4x lower memory usage, and it eliminated the need for kernel restarts, allowing uninterrupted workflows.
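When comparing backends on your own hardware, it helps to time the first call separately from later ones, because under JAX the first call typically includes XLA compilation. The generic timing helper below is a sketch and is not part of Meridian. If the function you time returns JAX arrays directly, call `.block_until_ready()` on them inside that function, because JAX dispatches work asynchronously and the timer could otherwise stop before the computation finishes.

```python
import statistics
import time
from typing import Any, Callable, Dict


def benchmark(fn: Callable[[], Any], repeats: int = 3) -> Dict[str, float]:
    """Time the first call of `fn` separately from the median of later calls."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    start = time.perf_counter()
    fn()  # Under JAX, this call usually includes tracing and compilation.
    first = time.perf_counter() - start
    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        later.append(time.perf_counter() - start)
    return {"first_call_s": first, "median_repeat_s": statistics.median(later)}
```

To use it, wrap your own call, such as a `sample_posterior` call with fixed settings, in a zero-argument lambda and run the same script once per backend.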
Because of this memory efficiency, you have more headroom to increase computationally intensive parameters. For example, in Meridian.sample_posterior(), you might increase the unrolled_leapfrog_steps argument (for example, from 1 to 5). This can accelerate convergence by increasing the trajectory length of the No-U-Turn Sampler (NUTS) without exceeding hardware memory limits. You can also increase the n_adapt parameter to further aid convergence during the adaptation phase.
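The adjustments above can be sketched as a thin wrapper around `sample_posterior`. The wrapper name `sample_with_more_headroom` is hypothetical, `mmm` is assumed to be an already configured Meridian model instance, and all default values are illustrative rather than recommendations. Tune them against your own convergence diagnostics.

```python
def sample_with_more_headroom(mmm, n_chains=4, n_adapt=1000, n_burnin=500,
                              n_keep=1000, unrolled_leapfrog_steps=5, seed=1):
    """Forward to mmm.sample_posterior() with a longer NUTS trajectory.

    Hypothetical helper; every default value here is illustrative only.
    """
    return mmm.sample_posterior(
        n_chains=n_chains,
        n_adapt=n_adapt,  # Longer adaptation phase to help convergence.
        n_burnin=n_burnin,
        n_keep=n_keep,
        # Leapfrog steps unrolled per NUTS tree expansion (default is 1).
        unrolled_leapfrog_steps=unrolled_leapfrog_steps,
        seed=seed,
    )
```

Keeping the overrides in one function makes it easy to run the same settings under both backends when you compare memory usage and convergence.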