Sandbox for playing with Sparse JAX and the sparse compiler on GPU.
See the colab; currently, it is only available to Googlers.
This colab needs to run on a special GPU runtime with XLA-NEXT:CPU+ enabled.
The XLA-NEXT:CPU+ compiler (the '+' is for GPU!) is needed to accelerate sparse operations on the GPU. Obtaining this compiler requires running from a custom-built colab binary, since the "GPU part" of XLA-NEXT is not in production.
Use the cell below to select between CUDA libgen and CUDA codegen for sparsified code (toggling this always needs a hard reset of the runtime):
- xla_cpu_sparse_cuda_threads=1: CUDA LIBGEN, i.e. use cuSPARSE
- xla_cpu_sparse_cuda_threads>1: CUDA CODEGEN, i.e. use CUDA threads
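For example, to switch from libgen to codegen, set the thread count above one and hard-reset the runtime (the value 32 below is just an illustrative choice, not a tuned setting):

import os
# CUDA CODEGEN: any value > 1 selects the codegen path; 32 is an arbitrary example.
os.environ['XLA_FLAGS'] = ('--xla_cpu_use_xla_runtime '
                           '--xla_cpu_enable_mlir_tiling_and_fusion=false '
                           '--xla_cpu_sparse_cuda_threads=32')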
Setup
import functools
import jax
from jax.experimental import sparse # this gives you sparse JAX
import jax.numpy as jnp
import os
from colabtools import adhoc_import
# Use XLA-NEXT:CPU+ for fast sparse ops
# --xla_cpu_use_xla_runtime enables XLA-NEXT which has fast sparse ops
# --xla_cpu_enable_mlir_tiling_and_fusion=false needed to workaround a few bugs
# --xla_cpu_sparse_cuda_threads enables the GPU part of XLA-NEXT:CPU+
#
# Here, the cuda threads value further controls the following:
#   xla_cpu_sparse_cuda_threads=1  CUDA LIBGEN,  i.e. use cuSPARSE
#   xla_cpu_sparse_cuda_threads>1  CUDA CODEGEN, i.e. use cuda threads
os.environ['XLA_FLAGS'] = '--xla_cpu_use_xla_runtime --xla_cpu_enable_mlir_tiling_and_fusion=false --xla_cpu_sparse_cuda_threads=1'
Importing modules from P4 HEAD.
Profile SDDMM on GPU
import matplotlib.pyplot as plt
import numpy as np
import timeit
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000 * 1_000
#DENSE_SHAPE = (128, 128)
DENSE_SHAPE = (1024, 1024)
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]
# SDDMM: sampled dense-dense matmul, i.e. the dense product a @ b sampled
# elementwise by the (sparse) matrix s.
def sddmm(s, a, b):
  return s * (a @ b)
sddmm_jit = jax.jit(sddmm)
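# NOTE: `sparse_jit` and `sparse_jax_sddmm` are defined in earlier colab cells
# that are not shown here. A minimal stand-in for `sparse_jit`, assuming the
# standard sparse JAX tracing path, could look like this (a sketch, not the
# colab's actual helper):
#
#   def sparse_jit(fun):
#     # sparsify() rewrites jnp ops to their BCOO/BCSR rules; jit() compiles.
#     return jax.jit(sparse.sparsify(fun))
#
# `sparse_jax_sddmm` is assumed to bind a custom GPU SDDMM primitive.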
sparse_sddmm_jit = sparse_jit(sddmm)
sparse_custom_sddmm_jit = sparse_jit(sparse_jax_sddmm)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
n_entries_sparse = []
t_us_dense = []
t_us_sparse_bcoo = []
t_us_sparse_bcsr = []
t_us_sparse_custom_bcsr = []
for i in range(10):
  fraction_non_zero = i / 10.0
  s = (
      jax.random.bernoulli(
          key=subkey,
          shape=DENSE_SHAPE,
          p=fraction_non_zero,
      )
      * jax.random.normal(
          key=subkey,
          shape=DENSE_SHAPE,
      )
  )
  a = jnp.ones(DENSE_SHAPE, dtype=jnp.float32)
  b = jnp.ones(DENSE_SHAPE, dtype=jnp.float32)
  sm_coo = sparse.BCOO.fromdense(s)
  sm_csr = sparse.BCSR.fromdense(s)
  print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')
  # Dense
  try:
    op = lambda: sddmm_jit(s, a, b).block_until_ready()
    dense_sddmm_result = op()  # Warmup
    t = timeit.repeat(
        op,
        repeat=REPEATS,
        number=LOOPS,
    )
    t_us_dense.append(SECONDS_TO_MICROS * max(t) / LOOPS)
    print(f' Dense ARRAY SDDMM completed in {int(t_us_dense[-1]):8,}us')
    del op
    del t
  except BaseException as e:
    print(f' Dense ARRAY SDDMM failed with error: {e}')
    t_us_dense.append(0)
  # Sparse BCOO
  try:
    n_entries_sparse.append(sm_coo.nse)
    op = lambda: sparse_sddmm_jit(sm_coo, a, b).block_until_ready()
    sparse_sddmm_result_coo = op()  # Warmup
    # Verify that dense and sparse path compute the same result.
    assert jnp.allclose(dense_sddmm_result, sparse_sddmm_result_coo.todense(), rtol=1e-2)
    t = timeit.repeat(
        op,
        repeat=REPEATS,
        number=LOOPS,
    )
    t_us_sparse_bcoo.append(SECONDS_TO_MICROS * max(t) / LOOPS)
    print(f' Sparse BCOO SDDMM completed in {int(t_us_sparse_bcoo[-1]):8,}us {sm_coo.nse} nse')
    del op
    del t
  except BaseException as e:
    print(f' Sparse BCOO SDDMM failed with error: {e}')
    t_us_sparse_bcoo.append(0)
  # Sparse BCSR
  try:
    op = lambda: sparse_sddmm_jit(sm_csr, a, b).block_until_ready()
    sparse_sddmm_result_csr = op()  # Warmup
    # Verify that dense and sparse path compute the same result.
    assert jnp.allclose(dense_sddmm_result, sparse_sddmm_result_csr.todense(), rtol=1e-2)
    t = timeit.repeat(
        op,
        repeat=REPEATS,
        number=LOOPS,
    )
    t_us_sparse_bcsr.append(SECONDS_TO_MICROS * max(t) / LOOPS)
    print(f' Sparse BCSR SDDMM completed in {int(t_us_sparse_bcsr[-1]):8,}us {sm_csr.nse} nse')
    del op
    del t
  except BaseException as e:
    print(f' Sparse BCSR SDDMM failed with error: {e}')
    t_us_sparse_bcsr.append(0)
  # Sparse custom op for SDDMM on GPU
  try:
    op = lambda: sparse_custom_sddmm_jit(sm_csr, a, b).block_until_ready()
    sparse_custom_sddmm_result_csr = op()  # Warmup
    # Verify that dense and sparse path compute the same result.
    #assert jnp.allclose(dense_sddmm_result, sparse_custom_sddmm_result_csr.todense(), rtol=1e-2)
    t = timeit.repeat(
        op,
        repeat=REPEATS,
        number=LOOPS,
    )
    t_us_sparse_custom_bcsr.append(SECONDS_TO_MICROS * max(t) / LOOPS)
    print(f' Sparse GPU SDDMM completed in {int(t_us_sparse_custom_bcsr[-1]):8,}us {sm_csr.nse} nse')
    del op
    del t
  except BaseException as e:
    print(f' Sparse GPU SDDMM failed with error: {e}')
    t_us_sparse_custom_bcsr.append(0)
  print(sparse_custom_sddmm_result_csr.todense())
  print()
density = [n / N_DENSE_ENTRIES for n in n_entries_sparse]
plt.figure(figsize=(10,7))
plt.plot(density, t_us_dense, 'o-')
plt.plot(density, t_us_sparse_bcoo, 'o-')
plt.plot(density, t_us_sparse_bcsr, 'o-')
plt.plot(density, t_us_sparse_custom_bcsr, 'o-')
plt.title(f'{DENSE_SHAPE} SDDMM time in microseconds vs density ({jax.devices()[0].platform}).');
plt.legend(['Dense SDDMM CPU (XLA NEXT)', 'SDDMM CPU (BCOO)', 'SDDMM CPU (BCSR)', 'SDDMM GPU (custom)']);
plt.xlabel('Fraction of non-zero entries');
plt.ylabel('microseconds');
plt.yscale('log');
SDDMM times in microseconds (with nse of the sparse operand):

 i  fraction   Dense ARRAY  Sparse BCOO  Sparse BCSR  Sparse GPU     nse
 0       0.0     2,910,600          274        1,035       2,684       0
 1       0.1     2,871,439      159,393      177,083      17,192  104638
 2       0.2     2,928,734      219,174      283,097       7,729  209835
 3       0.3     3,022,233      328,827      427,342       8,328  314145
 4       0.4     2,898,277      422,196      520,667       9,146  419046
 5       0.5     2,941,023      519,081      635,784       9,952  523768
 6       0.6     2,872,818      623,089      745,292      11,055  628597
 7       0.7     2,911,507      726,227      865,237      11,420  734097
 8       0.8     2,821,479      846,639    1,002,814      11,851  838722
 9       0.9     2,861,378      948,079    1,126,073      13,150  943895
Profile 2:4 on GPU
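NVIDIA's 2:4 structured sparsity keeps at most two nonzero values in every group of four consecutive elements, a pattern that sparse tensor cores can accelerate. As a rough illustration of such pruning (a sketch only; the pruning policy of the custom op used below is not shown in this colab), a matrix can be forced into 2:4 form like this:

import numpy as np

def prune_2to4(m):
  # Keep the two largest-magnitude entries in each group of four consecutive
  # elements along the flattened array; zero the other two. Assumes the total
  # size is divisible by 4 (true for all shapes used below).
  groups = np.asarray(m).reshape(-1, 4).copy()
  smallest = np.argsort(np.abs(groups), axis=1)[:, :2]  # 2 smallest per group
  np.put_along_axis(groups, smallest, 0, axis=1)
  return groups.reshape(np.shape(m))

# Example: prune_2to4(np.arange(8.0)) -> [0., 0., 2., 3., 0., 0., 6., 7.]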
import matplotlib.pyplot as plt
import numpy as np
import timeit
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000 * 1_000
# Dense baseline matching the custom 2:4 op's signature: c + a @ b.
def matmul24(c, a, b):
  return c + jnp.dot(a, b)
matmul24_jit = jax.jit(matmul24)
# `sparse_jax_2to4_spmm` binds a custom 2:4 SpMM op; it is defined in earlier
# colab cells that are not shown here.
custom24_jit = sparse_jit(sparse_jax_2to4_spmm)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
n_sizes = []
t_us_dense = []
t_us_custom = []
x = 16
for i in range(8):
  DENSE_SHAPE = (x, x)
  n_sizes.append(DENSE_SHAPE[0] * DENSE_SHAPE[1])
  s = (
      jax.random.bernoulli(
          key=subkey,
          shape=DENSE_SHAPE,
          p=0.5,  # keep it close to 2:4
      )
      * jax.random.normal(
          key=subkey,
          shape=DENSE_SHAPE,
      )
  )
  a = jnp.asarray(s, dtype=jnp.float16)
  #a = jnp.full(DENSE_SHAPE, 10.0, dtype=jnp.float16)
  b = jnp.ones(DENSE_SHAPE, dtype=jnp.float16)
  #b = jnp.identity(x, dtype=jnp.float16)
  c = jnp.zeros(DENSE_SHAPE, dtype=jnp.float16)
  x = x * 2
  # Dense
  try:
    op = lambda: matmul24_jit(c, a, b).block_until_ready()
    dense_result = op()  # Warmup
    t = timeit.repeat(
        op,
        repeat=REPEATS,
        number=LOOPS,
    )
    t_us_dense.append(SECONDS_TO_MICROS * max(t) / LOOPS)
    print(f' Dense ARRAY MatMul24 completed in {int(t_us_dense[-1]):8,}us {DENSE_SHAPE}')
    del op
    del t
  except BaseException as e:
    print(f' Dense ARRAY MatMul24 failed with error: {e}')
    t_us_dense.append(0)
  # Custom op for 2:4
  try:
    op = lambda: custom24_jit(c, a, b).block_until_ready()
    custom_result = op()  # Warmup
    # Note: the 2:4-pruned result is not the same as the dense result,
    # so there is no allclose() check here.
    t = timeit.repeat(
        op,
        repeat=REPEATS,
        number=LOOPS,
    )
    t_us_custom.append(SECONDS_TO_MICROS * max(t) / LOOPS)
    print(f' Sparse GPU NV2:4 completed in {int(t_us_custom[-1]):8,}us {DENSE_SHAPE}')
    del op
    del t
  except BaseException as e:
    print(f' Sparse GPU NV2:4 failed with error: {e}')
    print(dense_result)
    print(custom_result)
    t_us_custom.append(0)
  #print(dense_result)
  #print(custom_result)
  print()
plt.figure(figsize=(10,7))
plt.plot(n_sizes, t_us_dense, 'o-')
plt.plot(n_sizes, t_us_custom, 'o-')
plt.title(f'{DENSE_SHAPE} Matmul24 time in microseconds vs size ({jax.devices()[0].platform}).');
plt.legend(['Dense Matmul24 CPU (XLA NEXT)', 'Matmul24 GPU (custom)']);
plt.xlabel('size');
plt.ylabel('microseconds');
plt.yscale('log');
MatMul24 times in microseconds:

shape             Dense ARRAY  Sparse GPU NV2:4
(16, 16)                   12             1,555
(32, 32)                   41             4,083
(64, 64)                  357             3,351
(128, 128)              3,003               836
(256, 256)             25,408             1,465
(512, 512)            254,398             1,235
(1024, 1024)        3,703,789             4,868
(2048, 2048)       33,101,322            15,398