JAX-Backend verwenden

In diesem Leitfaden wird erläutert, wie Sie das JAX-Backend in Meridian verwenden.

Einführung in das JAX-Backend

Standardmäßig verwendet Meridian TensorFlow für die numerischen Kernoperationen und das probabilistische Markov-Chain-Monte-Carlo-Verfahren (MCMC) zur Stichprobenziehung. Das bietet eine robuste und gründlich getestete Grundlage für alle Modellierungsaufgaben.

Für Projekte, die von einer höheren Leistung und Arbeitsspeichereffizienz profitieren, bietet Meridian das JAX-Backend. JAX fördert einen funktionalen Programmierstil und nutzt die XLA-Kompilierung (Accelerated Linear Algebra), um erweiterte Leistungsoptimierungen zu ermöglichen.

Tutorial: Unter Getting started with JAX können Sie JAX in Aktion sehen.

JAX aktivieren

Die mathematischen Kernbibliotheken werden bei der Initialisierung geladen. Daher müssen Sie Meridian anweisen, JAX zu verwenden, bevor Sie Meridian-Module importieren.

Zum Aktivieren von JAX setzen Sie die Umgebungsvariable MERIDIAN_BACKEND auf 'jax'.

Sie müssen diese Umgebungsvariable in Ihrem Script festlegen, bevor Sie import meridian-Anweisungen ausführen:

import os

# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load

Doppelte Genauigkeit aktivieren

Für Modelle, bei denen die Konvergenz schwer zu erreichen ist, kann die Verwendung von doppelter Genauigkeit (64-Bit) mit dem JAX-Backend die numerische Stabilität verbessern. Doppelte Genauigkeit bietet zwar eine bessere numerische Stabilität, führt aber zu höherer Arbeitsspeichernutzung und längeren Berechnungszeiten. Daher bleibt die einfache Genauigkeit (32-Bit) für die meisten Anwendungsfälle der Standard. Um sie zu aktivieren, legen Sie die Umgebungsvariable MERIDIAN_ENABLE_JAX_X64 auf 'True' fest, bevor Sie Meridian importieren.

import os

# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'

# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model

Wird für die Umgebungsvariable MERIDIAN_BACKEND ein ungültiger String angegeben, gibt Meridian eine RuntimeWarning-Meldung aus und fällt automatisch auf die Standard-TensorFlow-Ausführung zurück. Wenn für die Umgebungsvariable MERIDIAN_ENABLE_JAX_X64 ein anderer Wert als „True“ oder „1“ angegeben wird, ist doppelte Genauigkeit nicht aktiviert und Meridian fällt standardmäßig auf die einfache Genauigkeit zurück.

Numerische Unterschiede und Reproduzierbarkeit

Da TensorFlow und JAX ihre Berechnungsgraphen unterschiedlich kompilieren, können bei der Umstellung auf JAX mit denselben Daten und Zufallswerten geringfügige numerische Unterschiede bei den Posterior-Schätzungen auftreten.

Auch wenn die Posterior-Verteilungen zwischen den Backends vielleicht nicht identisch sind, fallen Unterschiede in der Regel gering aus und sind für geschäftliche Messwerte wie den ROI und die Budgetzuweisung statistisch nicht signifikant. So bleibt die Integrität der Statistiken Ihres Modells beim Wechsel zum JAX-Backend erhalten.

Hinweise zur Leistung

Interne Tests haben ergeben, dass JAX initiale Modellläufe massiv beschleunigt: Bei der Verwendung von GPUs wurde die durchschnittliche Laufzeit im Vergleich zu TensorFlow um etwa 40 % und die Arbeitsspeichernutzung um etwa 70 % reduziert JAX hat außerdem die Modelliterationen optimiert und ermöglicht dadurch doppelt so schnelle Laufzeiten, eine um den Faktor 4 geringere Arbeitsspeichernutzung sowie unterbrechungsfreie Workflows, da Kernel-Neustarts entfallen.

Dank der erhöhten Arbeitsspeichereffizienz haben Sie mehr Spielraum, um rechenintensive Parameter anzupassen. In Meridian.sample_posterior() können Sie beispielsweise das Argument unrolled_leapfrog_steps erhöhen (etwa von 1 auf 5). Dies kann die Konvergenz beschleunigen, indem die Trajektorienlänge des No-U-Turn-Samplers (NUTS) erhöht wird, ohne die Hardware-Arbeitsspeicherlimits zu überschreiten. Sie können auch den Parameter n_adapt erhöhen, um die Konvergenz während der Anpassungsphase zu unterstützen.