Sandbox for experimenting with Sparse JAX and the sparse compiler on CPU.

See the Colab notebook. It is currently available only to Googlers.

In particular, this runs the Colab notebook "[jaxpruner] Sparse Model ViT", written by Utku Evci.

Setup

Importing modules from P4 HEAD.

JAXPRUNER Imports

JAXPRUNER: TF version =  2.15.0

JAXPRUNER Helper Methods

JAXPRUNER Model Initialization

JAXPRUNER Dense Run

JAXPRUNER: dense compiled 11.863374471664429 seconds ---
JAXPRUNER: dense run 58.369417905807495 seconds ---
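The notebook does not show how the compile and run times are measured. The following is a minimal sketch, assuming the measurement separates compilation from execution with JAX's ahead-of-time `jax.jit(...).lower(...).compile()` API. The toy `dense_forward` model, its shapes, and the `time_compile_and_run` helper are hypothetical stand-ins for the ViT and the notebook's own timing code.

```python
import time

import jax
import jax.numpy as jnp


def dense_forward(w1, w2, x):
    # Hypothetical toy MLP standing in for the ViT forward pass.
    return jnp.tanh(x @ w1) @ w2


def time_compile_and_run(fn, *args, iters=10):
    """Time ahead-of-time compilation and repeated execution of a jitted fn."""
    start = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()  # trace + XLA compile only
    compile_s = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(iters):
        out = compiled(*args)
    out.block_until_ready()  # JAX dispatches asynchronously; wait for results
    run_s = time.perf_counter() - start
    return compile_s, run_s, out


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (256, 512))
w2 = jax.random.normal(k2, (512, 10))
x = jax.random.normal(k3, (32, 256))

compile_s, run_s, out = time_compile_and_run(dense_forward, w1, w2, x)
print(f"dense compiled {compile_s} seconds ---")
print(f"dense run {run_s} seconds ---")
```

Splitting `lower(...).compile()` from the calls keeps one-time compilation cost out of the run measurement, which is why the log reports the two numbers separately.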

JAXPRUNER Sparse Run (using pure JAXPR)

JAXPRUNER: sparse JAXPR compiled 43.14039134979248 seconds ---
JAXPRUNER: sparse JAXPR run 12.49958348274231 seconds ---
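One plausible reading of "pure JAXPR" is JAX's own sparse transformation, `jax.experimental.sparse.sparsify`, which rewrites the traced jaxpr to use BCOO sparse primitives that are then lowered through regular XLA. The sketch below works under that assumption only. The magnitude-pruning step is a hypothetical stand-in for jaxpruner's masks, and the toy model is not the ViT.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def forward(w1, w2, x):
    # Hypothetical toy MLP standing in for the ViT forward pass.
    return jnp.tanh(x @ w1) @ w2


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (256, 512))
w2 = jax.random.normal(k2, (512, 10))
x = jax.random.normal(k3, (32, 256))

# Keep roughly the 10% largest-magnitude entries of w1 (hypothetical pruning).
threshold = jnp.quantile(jnp.abs(w1), 0.9)
w1_pruned = jnp.where(jnp.abs(w1) >= threshold, w1, 0.0)

# Store the pruned weights in JAX's batched-COO sparse format.
w1_sp = sparse.BCOO.fromdense(w1_pruned)

# sparsify() reinterprets the dense function so that ops touching w1_sp
# dispatch to sparse primitives; jit then compiles the result with XLA.
sparse_forward = jax.jit(sparse.sparsify(forward))
y_sparse = sparse_forward(w1_sp, w2, x)

# Reference: the same pruned weights run through the dense path.
y_dense = forward(w1_pruned, w2, x)
print(w1_sp.nse, "stored nonzeros out of", w1.size)
```

Comparing `y_sparse` against the dense forward pass on the same pruned weights checks that the sparse rewrite changes only how the computation is executed, not its result.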

JAXPRUNER Sparse Run (using XLA-Next Sparse Compiler)

JAXPRUNER: sparse XLA-NEXT compiled 15.409374713897705 seconds ---
JAXPRUNER: sparse XLA-NEXT run 9.1929030418396 seconds ---

JAXPRUNER Results

[Figure: JaxPruner on CPU]
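The timings logged above can be condensed into speedups relative to the dense run. This small sketch only reuses the wall-clock numbers printed by this notebook; the `timings` dictionary and `speedups` helper are illustrative names, not part of jaxpruner.

```python
# Wall-clock seconds copied from the JAXPRUNER log lines above.
timings = {
    "dense": {"compile": 11.863374471664429, "run": 58.369417905807495},
    "sparse JAXPR": {"compile": 43.14039134979248, "run": 12.49958348274231},
    "sparse XLA-NEXT": {"compile": 15.409374713897705, "run": 9.1929030418396},
}


def speedups(timings, baseline="dense"):
    """Run-time speedup of each configuration relative to the baseline."""
    base_run = timings[baseline]["run"]
    return {name: base_run / t["run"] for name, t in timings.items()}


for name, s in speedups(timings).items():
    total = timings[name]["compile"] + timings[name]["run"]
    print(f"{name:16s} run speedup {s:5.2f}x  compile+run {total:6.2f} s")
```

With these numbers the XLA-Next run is about 6.35x faster than the dense run (about 4.67x for pure JAXPR), and its compile step is about 2.8x faster than the pure-JAXPR compile, though still slower than the dense compile.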