Sandbox for playing with Sparse JAX and the sparse compiler on CPU.
See the Colab notebook. It is currently available only to Googlers.
This Colab can be executed on a Brain/DeepMind CPU runtime.
We need to use XLA-NEXT:CPU for fast sparse ops. The next cell shows how to set it up, and it lets you import sparse_jit either from P4 HEAD (the default) or from a CitC workspace (which only picks up Python-related changes, though).
You must restart the runtime before re-running the cell below to pick up changes:
- soft restart: go to Runtime, Restart Runtime
- hard restart: go to Connect to a Borg Runtime, Stop Runtime, Start New
Setup
import functools
import jax
from jax.experimental import sparse # this gives you sparse JAX
import jax.numpy as jnp
import os
from colabtools import adhoc_import
# Import-source settings. Presumably these select where sparse_jit is loaded
# from (P4 HEAD vs. a CitC workspace); the import step itself is not visible
# in this cell -- TODO confirm.
CITC_CLIENT = ""
CITC_USER = ""
CL_NUMBER = "HEAD"

# Use XLA-NEXT for fast sparse ops
_XLA_CPU_FLAGS = (
    # Enables XLA-NEXT which has fast sparse ops.
    '--xla_cpu_use_xla_runtime',
    # Needed to work around a few bugs.
    '--xla_cpu_enable_mlir_tiling_and_fusion=false',
)
os.environ['XLA_FLAGS'] = ' '.join(_XLA_CPU_FLAGS)
Importing modules from P4 HEAD.
Hello, sparse world!
# Build a sparse vector of logical length 1000 with six stored values in BCOO
# format; row k of `indices` is the position of `data[k]`.
data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
indices = jnp.array([0, 2, 4, 20, 500, 999], dtype=jnp.int32).reshape(-1, 1)
vec_bcoo = sparse.BCOO((data, indices), shape=(1000,))

print('Hello sparse world')
print(vec_bcoo, type(vec_bcoo))


def foo(x):
  """Return the dot product of `x` with itself."""
  return jnp.dot(x, x)


# Equivalent to decorating `foo` with @sparse_jit.
foo = sparse_jit(foo)

ret = foo(vec_bcoo)
print('Computed', ret)
Hello sparse world BCOO(float32[1000], nse=6) <class 'jax.experimental.sparse.bcoo.BCOO'> Computed 91.0
Profile matmul vs sparsity
import matplotlib.pyplot as plt
import timeit
# Timing configuration: every measurement calls the op LOOPS times, and the
# measurement is repeated REPEATS times.
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000_000

# Try reversing this shape to see how the relative difference in compute time
# between dense and sparse matmul changes.
DENSE_SHAPE = (8, 10_000)
#DENSE_SHAPE = (10_000, 8)
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]


def matmul(m1, m2):
  """Return the matrix product of `m1` and `m2`."""
  return jnp.matmul(m1, m2)


# The same function wrapped once with plain jax.jit and once with sparse_jit.
matmul_jit = jax.jit(matmul)
sparse_matmul_jit = sparse_jit(matmul, output_format=None)  # One of None, 'dense', 'sparse'

key, subkey = jax.random.split(jax.random.PRNGKey(0))

# One entry per sweep step: stored-element count and measured microseconds.
n_entries_sparse, t_us_dense, t_us_sparse_bcoo, t_us_sparse_bcsr = [], [], [], []
def _best_time_us(op):
  """Return the fastest per-call wall time of `op`, in microseconds.

  `op` is called LOOPS times per measurement and the measurement is repeated
  REPEATS times. The minimum over the repeats is reported, as recommended by
  the timeit documentation.
  """
  timings = timeit.repeat(op, repeat=REPEATS, number=LOOPS)
  return SECONDS_TO_MICROS * min(timings) / LOOPS


def _time_sparse_matmul(label, lhs, rhs, expected):
  """Time sparse_matmul_jit(lhs, rhs) and report the outcome.

  Args:
    label: Format name used in the printed messages (e.g. 'BCOO').
    lhs: Left sparse operand.
    rhs: Right sparse operand.
    expected: Dense reference result to compare against, or None to skip the
      comparison (used when the dense run itself failed).

  Returns:
    Microseconds per call, or NaN if the computation, the comparison against
    `expected`, or the timing failed.
  """
  try:
    op = lambda: sparse_matmul_jit(lhs, rhs).block_until_ready()
    result = op()  # Warmup
    # Verify that dense and sparse path compute the same result.
    # Note that the result returned by sparse_matmul_jit is dense.
    if expected is not None and not jnp.allclose(expected, result):
      raise AssertionError('sparse result differs from the dense result')
    elapsed_us = _best_time_us(op)
  except Exception as e:  # Deliberately not BaseException: keep Ctrl-C working.
    print(f' Sparse {label} matmul failed with error: {e}')
    return float('nan')
  print(f' Sparse {label} matmul completed in {int(elapsed_us):8,}us')
  return elapsed_us


# Alternative sweep that runs more data points in the sparse region:
#   for i in range(11):
#     fraction_non_zero = 10**-(0.5*(10-i))
# The sweep below runs data points uniformly.
for i in range(10):
  fraction_non_zero = i / 10.0
  # Draw fresh, independent keys for the sparsity mask and for the values.
  # Reusing one key for both makes the mask a function of the values, and
  # reusing it across iterations repeats the same random draw at every density.
  key, mask_key, value_key = jax.random.split(key, 3)
  m = (
      jax.random.bernoulli(key=mask_key, p=fraction_non_zero, shape=DENSE_SHAPE)
      * jax.random.normal(key=value_key, shape=DENSE_SHAPE)
  )
  mT = m.T
  # Note that the jax.experimental.sparse.BCOO.fromdense() method produces
  # a sorted COO as currently always expected by the code generated by the
  # sparse compiler. The operation sm.T, on the other hand, uses the
  # jax.sparse implementation which does not result in sorted COO, unless
  # we explicitly sparse_jit that operation too first. So, in general, for
  # the time being, make sure that all operations are handled by sparse_jit,
  # or otherwise only rely on the fromdense() method for BCOO construction.
  sm = sparse.BCOO.fromdense(m)
  smT = sparse.BCOO.fromdense(mT)
  s2m = sparse.BCSR.fromdense(m)
  s2mT = sparse.BCSR.fromdense(mT)
  # Record the density up front so every result list gets exactly one entry
  # per iteration, whether or not an individual measurement succeeds.
  n_entries_sparse.append(sm.nse)
  print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')

  # Dense
  dense_matmul_result = None  # Never compare against a stale reference.
  try:
    op = lambda: matmul_jit(m, mT).block_until_ready()
    dense_matmul_result = op()  # Warmup
    t_us_dense.append(_best_time_us(op))
    print(f' Dense ARRAY matmul completed in {int(t_us_dense[-1]):8,}us')
  except Exception as e:
    print(f' Dense matmul failed with error: {e}')
    # NaN (rather than 0 or a missing entry) keeps the lists aligned and
    # leaves a gap in the log-scale plot.
    t_us_dense.append(float('nan'))

  # Sparse BCOO
  t_us_sparse_bcoo.append(
      _time_sparse_matmul('BCOO', sm, smT, dense_matmul_result))

  # Sparse BCSR
  t_us_sparse_bcsr.append(
      _time_sparse_matmul('BCSR', s2m, s2mT, dense_matmul_result))
  print()
# Convert each stored-element count into a fraction of the dense matrix size.
density = [nse / N_DENSE_ENTRIES for nse in n_entries_sparse]

plt.figure(figsize=(10, 7))
# One curve per implementation, all drawn against the same density axis.
for timings, label in (
    (t_us_dense, 'Dense matmul'),
    (t_us_sparse_bcoo, 'Sparse matmul (BCOO)'),
    (t_us_sparse_bcsr, 'Sparse matmul (BCSR)'),
):
  plt.plot(density, timings, 'o-', label=label)
plt.title(f'{DENSE_SHAPE} x {tuple(reversed(DENSE_SHAPE))} matmul time in microseconds vs density ({jax.devices()[0].platform}).')
plt.legend()
plt.xlabel('Fraction of non-zero entries')
plt.ylabel('microseconds')
plt.yscale('log')
[ 0] Fraction of non-zero entries: 0.00000 Dense ARRAY matmul completed in 601us Sparse BCOO matmul completed in 17us Sparse BCSR matmul completed in 16us [ 1] Fraction of non-zero entries: 0.10000 Dense ARRAY matmul completed in 590us Sparse BCOO matmul completed in 282us Sparse BCSR matmul completed in 72us [ 2] Fraction of non-zero entries: 0.20000 Dense ARRAY matmul completed in 594us Sparse BCOO matmul completed in 548us Sparse BCSR matmul completed in 148us [ 3] Fraction of non-zero entries: 0.30000 Dense ARRAY matmul completed in 597us Sparse BCOO matmul completed in 717us Sparse BCSR matmul completed in 215us [ 4] Fraction of non-zero entries: 0.40000 Dense ARRAY matmul completed in 590us Sparse BCOO matmul completed in 818us Sparse BCSR matmul completed in 295us [ 5] Fraction of non-zero entries: 0.50000 Dense ARRAY matmul completed in 590us Sparse BCOO matmul completed in 898us Sparse BCSR matmul completed in 374us [ 6] Fraction of non-zero entries: 0.60000 Dense ARRAY matmul completed in 596us Sparse BCOO matmul completed in 965us Sparse BCSR matmul completed in 457us [ 7] Fraction of non-zero entries: 0.70000 Dense ARRAY matmul completed in 591us Sparse BCOO matmul completed in 1,001us Sparse BCSR matmul completed in 530us [ 8] Fraction of non-zero entries: 0.80000 Dense ARRAY matmul completed in 596us Sparse BCOO matmul completed in 1,009us Sparse BCSR matmul completed in 598us [ 9] Fraction of non-zero entries: 0.90000 Dense ARRAY matmul completed in 595us Sparse BCOO matmul completed in 972us Sparse BCSR matmul completed in 658us
Replicate matmul results from previous cell using nn.Dense model
import flax
from flax import linen as nn
import matplotlib.pyplot as plt
import timeit
# Timing configuration: every measurement calls the op LOOPS times, and the
# measurement is repeated REPEATS times.
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000_000

# Try reversing this shape to see how the relative difference in compute time
# between dense and sparse matmul changes.
DENSE_SHAPE = (8, 10_000)
# DENSE_SHAPE = (10_000, 8)
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]


def matmul(m1, m2):
  """Return the matrix product of `m1` and `m2`."""
  return jnp.matmul(m1, m2)


# The same function wrapped once with plain jax.jit and once with sparse_jit.
matmul_jit = jax.jit(matmul)
sparse_matmul_jit = sparse_jit(matmul)

key, subkey = jax.random.split(jax.random.PRNGKey(0))

# One entry per sweep step: stored-element count and measured microseconds.
n_entries_sparse, t_us_dense, t_us_sparse = [], [], []
def _best_time_us(op):
  """Return the fastest per-call wall time of `op`, in microseconds.

  `op` is called LOOPS times per measurement and the measurement is repeated
  REPEATS times; the minimum over the repeats is reported.
  """
  timings = timeit.repeat(op, repeat=REPEATS, number=LOOPS)
  return SECONDS_TO_MICROS * min(timings) / LOOPS


# The model does not depend on the sweep variable, so build it (and its dense
# jit wrapper) once instead of on every iteration.
model = nn.Sequential(
    [
        nn.Dense(features=DENSE_SHAPE[0], use_bias=False),
    ]
)
model_fn = model.apply
dense_model_fn_jit = jax.jit(model_fn)

for i in range(11):
  fraction_non_zero = 10**-(0.5*(10-i))
  # Draw fresh, independent keys for the sparsity mask, the values and the
  # parameter initialisation. Reusing one key for all of them makes the mask
  # a function of the values and repeats the same draw at every density.
  key, mask_key, value_key, init_key = jax.random.split(key, 4)
  m = (
      jax.random.bernoulli(key=mask_key, p=fraction_non_zero, shape=DENSE_SHAPE)
      * jax.random.normal(key=value_key, shape=DENSE_SHAPE)
  )
  mT = m.T
  print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')

  params = model.init({'params': init_key}, m)
  # Replace randomly initialized params with mT
  params = jax.tree_util.tree_map(lambda p: mT, params)
  # Now the model is equivalent to the matrix multiplication from the previous
  # section.
  assert jnp.allclose(matmul_jit(m, mT), model_fn(params, m)), (
      'model output differs from the plain matmul')

  # Record the density up front so every result list gets exactly one entry
  # per iteration, whether or not an individual measurement succeeds.
  sm = sparse.BCOO.fromdense(m)
  n_entries_sparse.append(sm.nse)

  # Dense
  dense_op_result = None  # Never compare against a stale reference.
  try:
    op = lambda: dense_model_fn_jit(params, m).block_until_ready()
    dense_op_result = op()  # Warmup
    t_us_dense.append(_best_time_us(op))
    print(f' Dense model completed in {int(t_us_dense[-1]):8,}us')
  except Exception as e:  # Deliberately not BaseException: keep Ctrl-C working.
    print(f' Dense model failed with error: {e}')
    # NaN keeps the lists aligned and leaves a gap in the log-scale plot.
    t_us_dense.append(float('nan'))

  # Sparse
  try:
    sparse_params = jax.tree_util.tree_map(sparse.BCOO.fromdense, params)
    sparse_model_fn_jit = sparse_jit(model_fn)
    op = lambda: sparse_model_fn_jit(sparse_params, sm).block_until_ready()
    sparse_op_result = op()  # Warmup
    # Note that the result returned by sparse_matmul_jit is a dense matrix at
    # the moment.
    if dense_op_result is not None and not jnp.allclose(
        dense_op_result, sparse_op_result):
      raise AssertionError('sparse result differs from the dense result')
    t_us_sparse.append(_best_time_us(op))
    print(f' Sparse model completed in {int(t_us_sparse[-1]):8,}us')
  except Exception as e:
    print(f' Sparse model failed with error: {e}')
    t_us_sparse.append(float('nan'))
  print()
# Convert each stored-element count into a fraction of the dense matrix size.
density = [nse / N_DENSE_ENTRIES for nse in n_entries_sparse]

plt.figure(figsize=(10, 7))
# One curve per implementation, both drawn against the same density axis.
for timings, label in (
    (t_us_dense, 'Dense model'),
    (t_us_sparse, 'Sparse model (BCOO)'),
):
  plt.plot(density, timings, 'o-', label=label)
plt.title(f'{DENSE_SHAPE} x {tuple(reversed(DENSE_SHAPE))} matmul time in microseconds vs density ({jax.devices()[0].platform}).')
plt.legend()
plt.xlabel('Fraction of non-zero entries')
plt.ylabel('microseconds')
plt.yscale('log')
[ 0] Fraction of non-zero entries: 0.00001 Dense model completed in 601us Sparse model completed in 17us [ 1] Fraction of non-zero entries: 0.00003 Dense model completed in 601us Sparse model completed in 15us [ 2] Fraction of non-zero entries: 0.00010 Dense model completed in 608us Sparse model completed in 15us [ 3] Fraction of non-zero entries: 0.00032 Dense model completed in 598us Sparse model completed in 16us [ 4] Fraction of non-zero entries: 0.00100 Dense model completed in 597us Sparse model completed in 18us [ 5] Fraction of non-zero entries: 0.00316 Dense model completed in 593us Sparse model completed in 21us [ 6] Fraction of non-zero entries: 0.01000 Dense model completed in 612us Sparse model completed in 37us [ 7] Fraction of non-zero entries: 0.03162 Dense model completed in 602us Sparse model completed in 89us [ 8] Fraction of non-zero entries: 0.10000 Dense model completed in 613us Sparse model completed in 310us [ 9] Fraction of non-zero entries: 0.31623 Dense model completed in 596us Sparse model completed in 796us [10] Fraction of non-zero entries: 1.00000 Dense model completed in 597us Sparse model completed in 671us
Sparse JAX Specifics
# List the primitives that have a sparsification rule for each sparse format.
print('Sparse JAX BCOO support sparsification of:')
print(sparse.transform.sparse_rules_bcoo.keys())
print('Sparse JAX BCSR support sparsification of:')  # Heading used to read "BCCSR".
print(sparse.transform.sparse_rules_bcsr.keys())

# The same six values as a dense array (Ad) and as a BCOO sparse array (As).
Ad = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype='float32')
As = sparse.BCOO.fromdense(Ad)


def bar(x, y):
  """Return the elementwise sum of `x` and `y`."""
  return jnp.add(x, y)
  # Alternative to try: return jnp.dot(x, y)


# Print the jaxpr for dense+dense, then the sparsified jaxprs for
# dense+sparse and sparse+sparse operands.
sparsified_bar = sparse.sparsify(bar)
print(jax.make_jaxpr(bar)(Ad, Ad))
print(jax.make_jaxpr(sparsified_bar)(Ad, As))
print(jax.make_jaxpr(sparsified_bar)(As, As))
Sparse JAX BCOO support sparsification of: dict_keys([abs, asin, asinh, atan, atanh, bessel_i1e, expm1, log1p, sign, sin, sinh, sqrt, tan, tanh, convert_element_type, copy, imag, neg, real, broadcast_in_dim, concatenate, conv_general_dilated, dot_general, dynamic_slice, reshape, rev, slice, squeeze, integer_pow, transpose, add, sub, mul, div, reduce_sum, gather, while, pjit, scan, cond, todense, custom_jvp_call]) Sparse JAX BCCSR support sparsification of: dict_keys([abs, asin, asinh, atan, atanh, bessel_i1e, expm1, log1p, sign, sin, sinh, sqrt, tan, tanh, convert_element_type, copy, imag, neg, real, dot_general, broadcast_in_dim, concatenate, integer_pow, todense, custom_jvp_call]) { lambda ; a:f32[6] b:f32[6]. let c:f32[6] = add a b in (c,) } { lambda ; a:f32[6] b:f32[6] c:i32[6,1]. let d:f32[6] = bcoo_todense[ spinfo=SparseInfo(shape=(6,), indices_sorted=True, unique_indices=True) ] b c e:f32[6] = add d a in (e,) } { lambda ; a:f32[6] b:i32[6,1] c:f32[6] d:i32[6,1]. let e:i32[12,1] = concatenate[dimension=0] b d f:f32[12] = concatenate[dimension=0] a c in (f, e) }