使用 JAX 后端

本指南介绍了如何在 Meridian 中使用 JAX 后端。

JAX 后端简介

默认情况下,Meridian 采用 TensorFlow 处理核心数值运算和概率性马尔可夫链蒙特卡洛 (MCMC) 抽样,为各类建模任务提供稳健且经过全面测试的基础。

对于那些能从更高性能和内存效率中获益的项目,Meridian 提供了 JAX 后端。JAX 提倡函数式编程风格,并利用 XLA(加速线性代数)编译技术来实现高级性能优化。

教程:如需查看 JAX 的实际应用,请参阅开始使用 JAX 笔记本。

如何启用 JAX

由于核心数学库会在初始化时加载,因此在导入任何 Meridian 模块之前,您必须先指示 Meridian 使用 JAX。

如需启用 JAX,请将 MERIDIAN_BACKEND 环境变量设置为 'jax'

您必须在脚本中设置此环境变量,然后才能执行任何 import meridian 语句:

import os

# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load

启用 64 位精度

对于难以收敛的模型,在 JAX 后端下采用 64 位精度可提升数值稳定性;不过,这会增加内存用量并拖慢计算速度。因此,在大多数应用场景下,系统仍默认采用 32 位精度。如需启用 64 位精度,请在导入 Meridian 之前将 MERIDIAN_ENABLE_JAX_X64 环境变量设置为 'True'

import os

# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'

# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model

若为 MERIDIAN_BACKEND 环境变量提供了无效字符串,Meridian 将发出 RuntimeWarning,并自动回退到标准的 TensorFlow 执行模式。若为 MERIDIAN_ENABLE_JAX_X64 环境变量提供的值不是“True”或“1”,则系统不会启用 64 位精度,Meridian 将默认采用 32 位精度。

数值差异和可复现性

由于 TensorFlow 和 JAX 采用不同的计算图编译机制,因此在切换至 JAX 后端时,即使数据和随机种子完全一致,后验估计值仍可能存在细微的数值差异。

虽然不同后端生成的后验分布可能不尽相同,但这些差异通常微乎其微,对于投资回报率和预算分配等业务指标而言,并不具备统计显著性。这可以确保切换至 JAX 后端时,模型生成的数据洞见依然完整可靠。

性能考虑因素

内部测试发现,在使用 GPU 的环境下,相较于 TensorFlow,JAX 可大幅提升模型初始运行效率,平均运行时间缩短约 40%,内存用量降低约 70%。JAX 还优化了模型迭代流程,运行速度提升至原来的 2 倍,内存用量减少四分之三;此外,由于无需重启内核,可确保工作流不间断运行。

由于内存效率显著提升,您将拥有更大的余量来调优那些计算密集型参数。例如,在 Meridian.sample_posterior() 中,您可以调高 unrolled_leapfrog_steps 参数(例如,从 1 增至 5)。这有助于在不超出硬件内存限制的前提下,通过增加 No-U-Turn-Sampler (NUTS) 的轨迹长度来加速收敛。您还可以调高 n_adapt 参数,在自适应阶段进一步辅助模型收敛。