This guide explains how to use the JAX backend in Meridian.
Introduction to the JAX backend
By default, Meridian uses TensorFlow for its core numerical operations and probabilistic Markov Chain Monte Carlo (MCMC) sampling, providing a robust and thoroughly tested foundation for all modeling tasks.
For projects that benefit from enhanced performance and memory efficiency, Meridian provides the JAX backend. JAX encourages a functional programming style and utilizes XLA (Accelerated Linear Algebra) compilation to offer advanced performance optimizations.
Tutorial: To see the JAX backend in action, refer to the Getting started with JAX notebook.
How to enable JAX
Because the core mathematical libraries are loaded at initialization, you must instruct Meridian to use JAX before importing any Meridian modules.
To enable JAX, set the MERIDIAN_BACKEND environment variable to 'jax'.
You must set this environment variable in your script before executing any
import meridian statements:
import os
# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load
Enable 64-bit precision
For models where convergence is difficult to achieve, using 64-bit precision
with the JAX backend can provide improved numerical stability. While 64-bit
precision offers better numerical stability, it comes with increased memory
usage and slower computation times. Therefore, 32-bit precision remains the
default for most use cases. To enable 64-bit precision, set the
MERIDIAN_ENABLE_JAX_X64 environment variable to 'True' before importing Meridian.
import os
# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'
# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
If an invalid string is provided to the MERIDIAN_BACKEND environment variable,
Meridian will issue a RuntimeWarning and default back to standard TensorFlow
execution. If a value other than 'True' or '1' is provided to the
MERIDIAN_ENABLE_JAX_X64 environment variable, 64-bit precision is not enabled,
and Meridian defaults to 32-bit precision.
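The following is a hypothetical sketch of the documented fallback rules, not Meridian's actual implementation; the function names, exact string matching, and warning message are assumptions made for illustration.

```python
import warnings

def resolve_backend(env):
    """Sketch: return 'jax' or 'tensorflow', warning on invalid values."""
    value = env.get('MERIDIAN_BACKEND', 'tensorflow')
    if value not in ('tensorflow', 'jax'):
        # Documented behavior: issue a RuntimeWarning and fall back
        # to standard TensorFlow execution.
        warnings.warn(
            f"Invalid MERIDIAN_BACKEND value {value!r}; "
            "falling back to TensorFlow.",
            RuntimeWarning,
        )
        return 'tensorflow'
    return value

def x64_enabled(env):
    """Sketch: 64-bit precision is enabled only for 'True' or '1'."""
    return env.get('MERIDIAN_ENABLE_JAX_X64', '') in ('True', '1')

print(resolve_backend({'MERIDIAN_BACKEND': 'jax'}))      # jax
print(resolve_backend({'MERIDIAN_BACKEND': 'torch'}))    # tensorflow
print(x64_enabled({'MERIDIAN_ENABLE_JAX_X64': 'true'}))  # False
```

Note that in this reading, values such as 'true' or 'TRUE' do not enable 64-bit precision; only the documented strings 'True' and '1' do.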
API differences when using JAX versus TensorFlow
When transitioning from TensorFlow to the JAX backend, there are key API differences you need to accommodate in your code:
Prior distributions
If you are setting custom prior distributions using PriorDistribution, you
must use tfp.substrates.jax.distributions instead of tfp.distributions. For
example:
TensorFlow
import tensorflow_probability as tfp
from meridian.model import prior_distribution
prior = prior_distribution.PriorDistribution(
    roi_m=tfp.distributions.LogNormal(0.2, 0.9)
)
JAX
import tensorflow_probability as tfp
tfp_jax = tfp.substrates.jax
from meridian.model import prior_distribution
prior = prior_distribution.PriorDistribution(
    roi_m=tfp_jax.distributions.LogNormal(0.2, 0.9)
)
Explicit seed requirement
When using the JAX backend, an explicit seed is required for stochastic
functions (for example, in sample_posterior()). While TensorFlow uses a global
random number generator that automatically picks a random seed, JAX makes this
seed explicit. We found no statistically significant differences in ROI
estimates or budget shifts across different seeds.
# Explicitly set a seed for MCMC sampling when using the JAX backend
mmm.sample_posterior(
    n_chains=2,
    n_adapt=1000,
    n_burnin=500,
    n_keep=1000,
    seed=0,
)
For more information on JAX random numbers and seeds, refer to the JAX pseudorandom numbers documentation.
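The practical consequence of explicit seeding is that randomness in JAX is a pure function of the key you pass in: reusing a key reproduces the same draws, and you split keys to get fresh ones. A minimal standalone illustration (plain JAX, not Meridian code):

```python
import jax

# JAX replaces the global random state with explicit PRNG keys.
key = jax.random.PRNGKey(0)

# The same key always yields the same draws, so a sampling run is
# fully reproducible given its seed.
draw_a = jax.random.normal(key, shape=(3,))
draw_b = jax.random.normal(key, shape=(3,))

# To obtain fresh randomness, split the key into subkeys instead of
# reusing it.
key, subkey = jax.random.split(key)
draw_c = jax.random.normal(subkey, shape=(3,))
```

This is why sample_posterior() requires a seed argument under the JAX backend: there is no hidden global generator to fall back on.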
Numerical differences and reproducibility
Because TensorFlow and JAX compile their computational graphs differently, you may observe minor numerical differences in your posterior estimates when switching to JAX using the same data and random seeds.
While posterior distributions might not be identical across backends, the differences are generally small and not statistically significant for business metrics such as ROI and budget allocation. This ensures that switching to the JAX backend maintains the integrity of your model's insights.
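A large part of why compiled backends disagree in the low-order bits is that floating-point addition is not associative, so a compiler that reorders operations can change the result slightly. A minimal NumPy illustration of the effect (not Meridian code):

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(1.0)

# float32 addition is not associative: at magnitude 1e8 the spacing
# between adjacent float32 values is 8, so adding 1.0 is lost to
# rounding entirely.
left = (a + b) - a   # 1e8 + 1 rounds back to 1e8, so this is 0.0
right = b + (a - a)  # exact cancellation happens first, so this is 1.0
print(left, right)
```

Differences of this size accumulate differently depending on how each backend's compiler schedules the computation, which is why posterior draws can differ at the margin even with identical data and seeds.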
Performance considerations
In internal testing on GPUs, the JAX backend reduced average runtime by roughly 40% and memory usage by roughly 70% for initial model runs compared to TensorFlow. For subsequent model iterations, it delivered roughly 2x faster runtimes and 4x lower memory usage, and eliminated the need for kernel restarts between runs.
Because of the increased memory efficiency, you have more headroom to adjust
computationally intensive parameters. For example, in
Meridian.sample_posterior(), you might increase the unrolled_leapfrog_steps
argument (e.g., from 1 to 5). This can accelerate convergence by increasing the
trajectory length of the No-U-Turn Sampler (NUTS) without exceeding hardware
memory limits. You can also increase the n_adapt parameter to further aid
convergence during the adaptation phase.
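A call that takes advantage of that headroom might look like the following sketch; the parameter values shown are illustrative, not tuned recommendations, and assume an already-constructed Meridian model object named mmm:

```python
# Illustrative only: with JAX's lower memory footprint, there may be
# room for longer NUTS trajectories and a longer adaptation phase.
mmm.sample_posterior(
    n_chains=2,
    n_adapt=2000,               # longer adaptation phase to aid convergence
    n_burnin=500,
    n_keep=1000,
    unrolled_leapfrog_steps=5,  # longer NUTS trajectories per step
    seed=0,
)
```

As always, monitor convergence diagnostics and memory usage when raising these parameters, since the right values depend on your data and hardware.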