Stay organized with collections
Save and categorize content based on your preferences.
Sandbox for playing with Sparse JAX and the sparse compiler on CPU.
See the Colab notebook. It is currently available only to Googlers.
In particular, this runs the colab [jaxpruner] Sparse Model ViT, written by Utku Evci.
Setup
import functools
import jax
# This is required to get Sparse JAX (experimental for now).
from jax.experimental import sparse
import jax.numpy as jnp
import os
from colabtools import adhoc_import
# Source-workspace selectors. They are not read anywhere in the visible cells;
# presumably the adhoc-import tooling consumes them -- TODO confirm.
CITC_CLIENT = ""
CITC_USER = ""
CL_NUMBER = "HEAD"
# Use XLA-NEXT for fast sparse ops.
#   --xla_cpu_use_xla_runtime enables XLA-NEXT, which has fast sparse ops.
#   --xla_cpu_enable_mlir_tiling_and_fusion=false is needed to work around a
#     few bugs; eventually it should be removed.
# NOTE: this assignment replaces any XLA_FLAGS value already present in the
# environment rather than appending to it.
os.environ['XLA_FLAGS'] = '--xla_cpu_use_xla_runtime --xla_cpu_enable_mlir_tiling_and_fusion=false'
Importing modules from P4 HEAD.
JAXPRUNER Imports
import tensorflow as tf
import flax
import time
import matplotlib.pyplot as plt
import pandas as pd
JAXPRUNER: TF version = 2.15.0
JAXPRUNER Helper Methods
# Magnitude cutoff for the ad-hoc pruning experiment. In the visible code it is
# referenced only by the commented-out line inside filtered_bcoo_simple, so it
# currently has no effect on the results.
PRUNE_THRESHOLD = 0.06
def initialize_model(model, placeholder_input, rngs):
    """Initialization function to be jitted.

    Calls ``model.flax_model.init`` in inference mode (``train=False``,
    ``debug=False``) and separates the ``'params'`` collection from the rest
    of the returned variables.

    Args:
        model: Object exposing a ``flax_model`` attribute that has an ``init``
            method.
        placeholder_input: Example input forwarded to ``init``.
        rngs: RNG argument forwarded to ``init`` as its first positional
            argument.

    Returns:
        A ``(init_params, init_model_state)`` tuple: the popped ``'params'``
        entry and whatever remains of the initialized variables.
    """
    # `pop('params')` must return a `(remaining, value)` pair for this
    # unpacking to work, which is what flax's FrozenDict.pop does.
    # NOTE(review): a plain dict's `pop` returns only the value -- confirm the
    # flax version in use still returns a FrozenDict from `init`.
    init_model_state, init_params = model.flax_model.init(
        rngs, placeholder_input, train=False, debug=False
    ).pop('params')
    return init_params, init_model_state
def filtered_bcoo_simple(key, param):
    """Convert 2-D and 3-D ``kernel`` parameters to BCOO; pass others through.

    Args:
        key: Path of the parameter as a sequence; only its last element is
            inspected.
        param: Parameter value; must expose ``ndim``.

    Returns:
        ``sparse.BCOO.fromdense(param)`` when the last path element is
        ``'kernel'`` and ``param.ndim`` is 2 or 3; otherwise ``param`` itself,
        unchanged.
    """
    if key[-1] == 'kernel' and 1 < param.ndim < 4:
        # Aart made this pruning up, this has no ML semantics AFAIK.
        # Disabled: zero out small-magnitude entries before the conversion.
        # param = jnp.where(abs(param) < PRUNE_THRESHOLD, jnp.zeros(param.shape), param)
        return sparse.BCOO.fromdense(param)
    return param
def total(val):
    """Return the product of all entries of ``val``.

    Used on array shapes to obtain the total element count. An empty ``val``
    (the shape of a scalar) yields 1 instead of raising.

    Args:
        val: Iterable of numbers, e.g. a shape tuple.

    Returns:
        The product of the entries, or 1 when ``val`` is empty.
    """
    # The explicit initial value 1 makes the empty case well defined.
    return functools.reduce(lambda x, y: x * y, val, 1)
JAXPRUNER Model Initialization
# `imagenet_vit_config` and `models` are not imported in the visible cells;
# presumably they are provided by the adhoc import from HEAD -- TODO confirm.
config = imagenet_vit_config.get_config()
# Dataset description handed to the model constructor below.
dataset_meta_data = {
'input_dtype': jax.numpy.float32,
'input_shape': (-1, 224, 224, 3),
'num_classes': 1000,
'num_eval_examples': 50000,
'num_train_examples': 1281167,
'target_is_onehot': False,
}
# Look up the model class by the name stored in the config, then build it.
model_cls = models.get_model_cls(config.model_name)
model = model_cls(config, dataset_meta_data)
# NOTE(review): the `init_rng` produced by this first split is overwritten by
# the second split below before it is ever used.
rng, init_rng = jax.random.split(jax.random.PRNGKey(8))
# Initialize model.
rng, init_rng = jax.random.split(rng)
# Array of ones with shape (1, 224, 224, 3), passed to initialize_model.
placeholder_input = jnp.ones((1, 224, 224, 3))
init_params, init_model_state = initialize_model(
model, placeholder_input, {'params': init_rng}
)
# Bundle the freshly initialized values into a train state; it is passed to
# the checkpoint restore call below.
# `train_utils` is not imported in the visible cells -- TODO confirm its origin.
initial_train_state = train_utils.TrainState(
global_step=0, params=init_params, model_state=init_model_state, rng=rng
)
# Uses hardcoded cns for testing.
init_checkpoint_path = (
'/cns/vz-d/home/brain-sparsity/evcu/sparsevit-checkpoints/sparse_b16_80/'
)
# Load the pretrained checkpoint. `pretrain_utils` is not imported in the
# visible cells; presumably `assert_exist=True` makes a missing checkpoint an
# error -- TODO confirm.
restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
init_checkpoint_path, initial_train_state, assert_exist=True
)
# Convert every eligible parameter of the restored model to BCOO. The tree is
# flattened so filtered_bcoo_simple can see each parameter's full path.
res_dict = {
    k: filtered_bcoo_simple(k, p)
    for k, p in flax.traverse_util.flatten_dict(restored_train_state.params).items()
}
sparse_params = flax.traverse_util.unflatten_dict(res_dict)

# Report shape, number of stored elements and density of each BCOO parameter.
i = 0
for k, v in flax.traverse_util.flatten_dict(sparse_params).items():
    if isinstance(v, sparse.BCOO):
        density = v.nse / total(v.shape)
        print('JAXPRUNER: BCOO @', i, k[-2:], v.shape, v.nse, density)
        # NOTE(review): the original nesting of this increment was lost; it is
        # assumed to number the BCOO tensors only -- confirm against the notebook.
        i = i + 1
# Get the variables.
# Both dicts spread the same `init_model_state` entries after 'params'; they
# differ only in whether 'params' holds the restored dense weights or the
# BCOO-converted ones.
variables = {'params': restored_train_state.params, **init_model_state}
sp_variables = {'params': sparse_params, **init_model_state}
print('JAXPRUNER: ready to run from', init_checkpoint_path)
JAXPRUNER Dense Run
# Time lowering + compilation of the dense forward pass (inference mode).
# perf_counter is used because it is a monotonic, high-resolution clock.
start_time = time.perf_counter()
lowered = jax.jit(
    functools.partial(model.flax_model.apply, train=False)
).lower(variables, jnp.ones((1, 224, 224, 3)))
compiled = lowered.compile()
dense_c_time = time.perf_counter() - start_time
print('JAXPRUNER: dense compiled %s seconds ---' % (dense_c_time))

# Time a single execution of the compiled function. JAX may dispatch work
# asynchronously, so wait for the result before stopping the clock.
start_time = time.perf_counter()
res_dense = jax.block_until_ready(
    compiled(variables, jnp.ones((1, 224, 224, 3)))
)
dense_r_time = time.perf_counter() - start_time
print('JAXPRUNER: dense run %s seconds ---' % (dense_r_time))
JAXPRUNER: dense compiled 11.863374471664429 seconds ---
JAXPRUNER: dense run 58.369417905807495 seconds ---
JAXPRUNER Sparse Run (using pure JAXPR)
# Time lowering + compilation of the forward pass after wrapping it with
# sparse.sparsify, using the BCOO-converted parameters.
# perf_counter is used because it is a monotonic, high-resolution clock.
start_time = time.perf_counter()
sparse_lower = (
    jax.jit(
        sparse.sparsify(
            functools.partial(model.flax_model.apply, train=False)
        )
    )
    .lower(sp_variables, jnp.ones((1, 224, 224, 3)))
)
sparse_apply = sparse_lower.compile()
sparse_jaxpr_c_time = time.perf_counter() - start_time
print('JAXPRUNER: sparse JAXPR compiled %s seconds ---' % (sparse_jaxpr_c_time))

# Time a single execution. JAX may dispatch work asynchronously, so wait for
# the result before stopping the clock.
start_time = time.perf_counter()
res_jaxpr = jax.block_until_ready(
    sparse_apply(sp_variables, jnp.ones((1, 224, 224, 3)))
)
sparse_jaxpr_r_time = time.perf_counter() - start_time
print('JAXPRUNER: sparse JAXPR run %s seconds ---' % (sparse_jaxpr_r_time))
JAXPRUNER: sparse JAXPR compiled 43.14039134979248 seconds ---
JAXPRUNER: sparse JAXPR run 12.49958348274231 seconds ---
JAXPRUNER Sparse Run (using XLA-Next Sparse Compiler)
# Time lowering + compilation through `sparse_jit`, using the BCOO-converted
# parameters. `sparse_jit` is not defined in the visible cells; presumably it
# is provided by the adhoc import -- TODO confirm.
# perf_counter is used because it is a monotonic, high-resolution clock.
start_time = time.perf_counter()
sparse_lower = (
    sparse_jit(functools.partial(model.flax_model.apply, train=False))
    .lower(sp_variables, jnp.ones((1, 224, 224, 3)))
)
sparse_apply = sparse_lower.compile()
sparse_xla_next_c_time = time.perf_counter() - start_time
print('JAXPRUNER: sparse XLA-NEXT compiled %s seconds ---' % (sparse_xla_next_c_time))

# Time a single execution. Wait for the result before stopping the clock so
# the measurement is comparable with the other runs.
start_time = time.perf_counter()
res_xla_next = jax.block_until_ready(
    sparse_apply(sp_variables, jnp.ones((1, 224, 224, 3)))
)
sparse_xla_next_r_time = time.perf_counter() - start_time
print('JAXPRUNER: sparse XLA-NEXT run %s seconds ---' % (sparse_xla_next_r_time))
JAXPRUNER: sparse XLA-NEXT compiled 15.409374713897705 seconds ---
JAXPRUNER: sparse XLA-NEXT run 9.1929030418396 seconds ---
JAXPRUNER Results
# Arrange the six timings as a table: one row per phase (compile / run) and
# one column per execution variant.
axis = ['dense', 'sparse_jaxpr', 'sparse_xla_next']
data1 = [dense_c_time, sparse_jaxpr_c_time, sparse_xla_next_c_time]
data2 = [dense_r_time, sparse_jaxpr_r_time, sparse_xla_next_r_time]
df = pd.DataFrame(
    data=[data1, data2],
    index=['compile time','runtime'],
    columns=axis,
)
# Draw the table as a bar chart with horizontal tick labels and give it a title.
df.plot.bar(rot=0).set_title("JaxPruner on CPU")
Text(0.5, 1.0, 'JaxPruner on CPU')

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-03 UTC.
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-09-03 UTC."],[[["\u003cp\u003eThis webpage provides a sandbox environment for experimenting with Sparse JAX and its sparse compiler on CPU, specifically focusing on the JaxPruner Sparse Model ViT.\u003c/p\u003e\n"],["\u003cp\u003eThe sandbox uses a pre-trained ViT model and demonstrates how to convert it to a sparse format using BCOO (Block-sparse COO) representation.\u003c/p\u003e\n"],["\u003cp\u003eIt showcases and benchmarks three different runtime scenarios: dense, sparse using JAXPR, and sparse using the XLA-Next sparse compiler.\u003c/p\u003e\n"],["\u003cp\u003eThe benchmark results, visualized in a bar chart, compare compilation and runtime performance for each scenario, indicating the potential benefits of using sparse representations and the sparse compiler.\u003c/p\u003e\n"],["\u003cp\u003eCurrently, the sandbox is only accessible to Googlers through a provided Colab link.\u003c/p\u003e\n"]]],["The document outlines using Sparse JAX and its compiler on CPU for a Sparse Model ViT. Key actions include setting up the environment with specific XLA flags, initializing the model, and restoring a pre-trained checkpoint. The code defines a function for sparse BCOO conversion. Three runs are then timed: a dense run, a sparse run with JAXPR, and a sparse run with XLA-Next. These runs compare compilation and execution times, which are finally plotted on a graph.\n"],null,["Sandbox for playing with Sparse JAX and the sparse compiler on CPU.\n\nSee [colab](https://colab.corp.google.com/drive/1YvtreJb07lkd1l9gZ06qzx_70oPuzWe1). 
Currently, it is only available to googlers.\n\n*** ** * ** ***\n\nIn particular, this runs the colab \\[jaxpruner\\] Sparse Model ViT, written by Utku Evci.\n\n### Setup\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n import functools\n import jax\n\n # This is required to get Sparse JAX (experimental for now).\n from jax.experimental import sparse\n\n import jax.numpy as jnp\n import os\n\n from colabtools import adhoc_import\n\n\n\n CITC_CLIENT = \"\" \n\n CITC_USER = \"\" \n\n CL_NUMBER = \"HEAD\"\n\n # Use XLA-NEXT for fast sparse ops\n # --xla_cpu_use_xla_runtime enables XLA-NEXT which has fast sparse ops\n # --xla_cpu_enable_mlir_tiling_and_fusion=false needed to workaround a few bugs.\n # Eventually, it should be removed\n os.environ['XLA_FLAGS'] = '--xla_cpu_use_xla_runtime --xla_cpu_enable_mlir_tiling_and_fusion=false'\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nImporting modules from P4 HEAD.\n```\n\n### JAXPRUNER Imports\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n import tensorflow as tf\n\n import flax\n import time\n\n import matplotlib.pyplot as plt\n import pandas as pd\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nJAXPRUNER: TF version = 2.15.0\n```\n\n### JAXPRUNER Helper Methods\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n PRUNE_THRESHOLD = 0.06\n\n def initialize_model(model, placeholder_input, rngs):\n \"\"\"Initialization function to be jitted.\"\"\"\n init_model_state, init_params = model.flax_model.init(\n rngs, placeholder_input, train=False, debug=False\n ).pop('params')\n return init_params, init_model_state\n\n def filtered_bcoo_simple(key, param):\n if key[-1] == 'kernel' and 4 \u003e param.ndim \u003e 1:\n # Aart made this pruning up, this has no ML semantics AFAIK.\n # param = jnp.where(abs(param) \u003c PRUNE_THRESHOLD, jnp.zeros(param.shape), param)\n return sparse.BCOO.fromdense(param)\n else:\n return param\n\n def total(val):\n return functools.reduce((lambda x, y: x * y), 
val)\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n### JAXPRUNER Model Initialization\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n config = imagenet_vit_config.get_config()\n\n dataset_meta_data = {\n 'input_dtype': jax.numpy.float32,\n 'input_shape': (-1, 224, 224, 3),\n 'num_classes': 1000,\n 'num_eval_examples': 50000,\n 'num_train_examples': 1281167,\n 'target_is_onehot': False,\n }\n\n model_cls = models.get_model_cls(config.model_name)\n model = model_cls(config, dataset_meta_data)\n rng, init_rng = jax.random.split(jax.random.PRNGKey(8))\n\n # Initialize model.\n rng, init_rng = jax.random.split(rng)\n placeholder_input = jnp.ones((1, 224, 224, 3))\n\n init_params, init_model_state = initialize_model(\n model, placeholder_input, {'params': init_rng}\n )\n initial_train_state = train_utils.TrainState(\n global_step=0, params=init_params, model_state=init_model_state, rng=rng\n )\n\n # Uses hardcoded cns for testing.\n init_checkpoint_path = (\n '/cns/vz-d/home/brain-sparsity/evcu/sparsevit-checkpoints/sparse_b16_80/'\n )\n\n restored_train_state = pretrain_utils.restore_pretrained_checkpoint(\n init_checkpoint_path, initial_train_state, assert_exist=True\n )\n\n res_dict = {}\n for k, p in flax.traverse_util.flatten_dict(restored_train_state.params).items():\n res_dict[k] = filtered_bcoo_simple(k, p)\n\n sparse_params = flax.traverse_util.unflatten_dict(res_dict)\n\n i = 0\n for k, v in flax.traverse_util.flatten_dict(sparse_params).items():\n if isinstance(v, sparse.BCOO):\n density = v.nse / total(v.shape)\n print('JAXPRUNER: BCOO @', i, k[-2:], v.shape, v.nse, density)\n i = i + 1\n\n # Get the variables.\n variables = {'params': restored_train_state.params, **init_model_state}\n sp_variables = {'params': sparse_params, **init_model_state}\n\n print('JAXPRUNER: ready to run from', init_checkpoint_path)\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n### JAXPRUNER Dense Run\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n start_time = 
time.time()\n lowered = jax.jit(\n functools.partial(model.flax_model.apply, train=False)\n ).lower(variables, jnp.ones((1, 224, 224, 3)))\n compiled = lowered.compile()\n dense_c_time = time.time() - start_time\n print('JAXPRUNER: dense compiled %s seconds ---' % (dense_c_time))\n\n start_time = time.time()\n res_dense = compiled(variables, jnp.ones((1, 224, 224, 3)))\n dense_r_time = time.time() - start_time\n print('JAXPRUNER: dense run %s seconds ---' % (dense_r_time))\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nJAXPRUNER: dense compiled 11.863374471664429 seconds ---\nJAXPRUNER: dense run 58.369417905807495 seconds ---\n```\n\n### JAXPRUNER Sparse Run (using pure JAXPR)\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n start_time = time.time()\n sparse_lower = (\n jax.jit(\n sparse.sparsify(\n functools.partial(model.flax_model.apply, train=False)\n )\n )\n .lower(sp_variables, jnp.ones((1, 224, 224, 3)))\n )\n sparse_apply = sparse_lower.compile()\n sparse_jaxpr_c_time = time.time() - start_time\n print('JAXPRUNER: sparse JAXPR compiled %s seconds ---' % (sparse_jaxpr_c_time))\n\n start_time = time.time()\n res_jaxpr = sparse_apply(sp_variables, jnp.ones((1, 224, 224, 3)))\n sparse_jaxpr_r_time = time.time() - start_time\n print('JAXPRUNER: sparse JAXPR run %s seconds ---' % (sparse_jaxpr_r_time))\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nJAXPRUNER: sparse JAXPR compiled 43.14039134979248 seconds ---\nJAXPRUNER: sparse JAXPR run 12.49958348274231 seconds ---\n```\n\n### JAXPRUNER Sparse Run (using XLA-Next Sparse Compiler)\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n start_time = time.time()\n sparse_lower = (\n sparse_jit(functools.partial(model.flax_model.apply, train=False)\n )\n .lower(sp_variables, jnp.ones((1, 224, 224, 3)))\n )\n sparse_apply = sparse_lower.compile()\n sparse_xla_next_c_time = time.time() - start_time\n print('JAXPRUNER: sparse XLA-NEXT compiled %s seconds ---' % (sparse_xla_next_c_time))\n\n start_time 
= time.time()\n res_xla_next = sparse_apply(sp_variables, jnp.ones((1, 224, 224, 3)))\n sparse_xla_next_r_time = time.time() - start_time\n print('JAXPRUNER: sparse XLA-NEXT run %s seconds ---' % (sparse_xla_next_r_time))\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nJAXPRUNER: sparse XLA-NEXT compiled 15.409374713897705 seconds ---\nJAXPRUNER: sparse XLA-NEXT run 9.1929030418396 seconds ---\n```\n\n### JAXPRUNER Results\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n axis = ['dense', 'sparse_jaxpr', 'sparse_xla_next']\n data1 = [dense_c_time, sparse_jaxpr_c_time, sparse_xla_next_c_time]\n data2 = [dense_r_time, sparse_jaxpr_r_time, sparse_xla_next_r_time]\n df = pd.DataFrame([data1, data2], index=['compile time','runtime'], columns=axis)\n\n df.plot.bar(rot=0).set_title(\"JaxPruner on CPU\")\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nText(0.5, 1.0, 'JaxPruner on CPU')\n```"]]