Sandbox for playing with Sparse JAX and the sparse compiler on GPU.
See the colab. Currently, it is only available to Googlers.
This colab runs on a special GPU runtime with XLA-NEXT:CPU+ enabled.
The XLA-NEXT:CPU+ compiler (the plus is for GPU!) is needed to accelerate sparse operations on the GPU. Obtaining this compiler requires running from a custom-built colab binary, since the "GPU part" of XLA-NEXT is not in production.
Use the cell below to select between CUDA libgen and CUDA codegen for sparsified code (toggling this always requires a hard reset of the runtime):
- xla_cpu_sparse_cuda_threads=1: CUDA LIBGEN, i.e. use cuSPARSE
- xla_cpu_sparse_cuda_threads>1: CUDA CODEGEN, i.e. use CUDA threads
Setup
import functools
import jax
from jax.experimental import sparse # this gives you sparse JAX
import jax.numpy as jnp
import os
from colabtools import adhoc_import
# Select XLA-NEXT:CPU+ so sparse ops are fast. One flag per entry:
#   --xla_cpu_use_xla_runtime
#       Turns on XLA-NEXT, which provides the fast sparse ops.
#   --xla_cpu_enable_mlir_tiling_and_fusion=false
#       Works around a few bugs.
#   --xla_cpu_sparse_cuda_threads=N
#       Turns on the GPU part of XLA-NEXT:CPU+. N also picks the GPU strategy:
#         N == 1  CUDA LIBGEN, i.e. use cuSPARSE
#         N >  1  CUDA CODEGEN, i.e. use CUDA threads
#       Changing N requires a hard reset of the runtime.
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_cpu_use_xla_runtime',
    '--xla_cpu_enable_mlir_tiling_and_fusion=false',
    '--xla_cpu_sparse_cuda_threads=1',
])
Importing modules from P4 HEAD.
Profile SpMV on GPU
import matplotlib.pyplot as plt
import numpy as np
import timeit
# timeit settings: number of repeats, and calls timed per repeat.
REPEATS, LOOPS = 1, 10
SECONDS_TO_MICROS = 10**6
# A square 16K x 16K matrix, multiplied by a vector with one entry per column.
DENSE_SHAPE = (16 * 1024,) * 2
VSHAPE = DENSE_SHAPE[1:]
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]
def matvec(m, v):
    """Return the matrix-vector product of `m` and `v`, computed with `jnp.dot`.

    Left un-jitted so the notebook can wrap it with both `jax.jit` and
    `sparse_jit`.
    """
    product = jnp.dot(m, v)
    return product
matvec_jit = jax.jit(matvec)
# NOTE(review): `sparse_jit` is not defined in the visible part of this
# notebook; presumably an earlier cell provides a jit wrapper that accepts
# BCOO/BCSR operands — confirm.
sparse_matvec_jit = sparse_jit(matvec)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
n_entries_sparse = []
t_us_dense = []
# NOTE: COO AoS is deprecated in cuSPARSE 11.2, so BCOO SpMV falls back to
# the sparse CPU path.
t_us_sparse_bcoo = []
t_us_sparse_bcsr = []


def _time_us(op):
    """Return the per-call wall time of `op` in microseconds.

    Runs `op` LOOPS times per repeat, REPEATS times, and keeps the fastest
    repeat. Per the `timeit` docs, slower repeats reflect interference from
    other processes rather than the cost of `op`.
    """
    times = timeit.repeat(op, repeat=REPEATS, number=LOOPS)
    return SECONDS_TO_MICROS * min(times) / LOOPS


v = jnp.ones(VSHAPE)  # Loop-invariant right-hand side.
for i in range(10):
    fraction_non_zero = i / 100.0
    # Use separate keys for the sparsity mask and the values. Reusing
    # `subkey` for both would derive them from the same random bits, which
    # correlates which entries survive with the values they hold.
    mask_key, values_key = jax.random.split(subkey)
    m = (
        jax.random.bernoulli(key=mask_key, shape=DENSE_SHAPE, p=fraction_non_zero)
        * jax.random.normal(key=values_key, shape=DENSE_SHAPE)
    )
    sm_coo = jax.experimental.sparse.BCOO.fromdense(m)
    sm_csr = jax.experimental.sparse.BCSR.fromdense(m)
    print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')
    # Record the entry count up front so `n_entries_sparse` stays aligned
    # with the timing lists even when a benchmark below fails.
    n_entries_sparse.append(sm_coo.nse)

    # Dense. A failed run records NaN, which keeps every timing list as long
    # as `n_entries_sparse` and leaves a gap in the log-scale plot. The
    # reference is reset so a failure cannot leave the previous density's
    # result behind for the sparse checks below.
    dense_matvec_result = None
    try:
        op = lambda: matvec_jit(m, v).block_until_ready()
        dense_matvec_result = op()  # Warmup
        t_us_dense.append(_time_us(op))
        print(f' Dense ARRAY matvec completed in {int(t_us_dense[-1]):8,}us')
    except Exception as e:  # Let KeyboardInterrupt/SystemExit propagate.
        print(f' Dense ARRAY matvec failed with error: {e}')
        t_us_dense.append(float('nan'))

    # Sparse BCOO
    sparse_matvec_result_coo = None
    try:
        op = lambda: sparse_matvec_jit(sm_coo, v).block_until_ready()
        sparse_matvec_result_coo = op()  # Warmup
        # Verify that the dense and sparse paths compute the same result
        # (skipped when there is no dense reference for this density).
        if dense_matvec_result is not None and not jnp.allclose(
                dense_matvec_result, sparse_matvec_result_coo, rtol=1e-2):
            raise AssertionError('BCOO SpMV result differs from dense matvec')
        t_us_sparse_bcoo.append(_time_us(op))
        print(f' Sparse BCOO SpMV completed in {int(t_us_sparse_bcoo[-1]):8,}us {sm_coo.nse} nse')
    except Exception as e:
        print(f' Sparse BCOO SpMV failed with error: {e}')
        t_us_sparse_bcoo.append(float('nan'))
        print(dense_matvec_result)
        print(sparse_matvec_result_coo)

    # Sparse BCSR
    sparse_matvec_result_csr = None
    try:
        op = lambda: sparse_matvec_jit(sm_csr, v).block_until_ready()
        sparse_matvec_result_csr = op()  # Warmup
        # Verify that the dense and sparse paths compute the same result
        # (skipped when there is no dense reference for this density).
        if dense_matvec_result is not None and not jnp.allclose(
                dense_matvec_result, sparse_matvec_result_csr, rtol=1e-2):
            raise AssertionError('BCSR SpMV result differs from dense matvec')
        t_us_sparse_bcsr.append(_time_us(op))
        print(f' Sparse BCSR SpMV completed in {int(t_us_sparse_bcsr[-1]):8,}us {sm_csr.nse} nse')
    except Exception as e:
        print(f' Sparse BCSR SpMV failed with error: {e}')
        t_us_sparse_bcsr.append(float('nan'))
        print(dense_matvec_result)
        print(sparse_matvec_result_csr)
    print()
# 16K^2 CPU/GPU, 10 data points (one per density).
# NOTE(review): these timings (microseconds) are hard-coded from separate
# runs — presumably dense matvec on a plain XLA GPU runtime and BCSR SpMV on
# the CPU. Confirm they match the current DENSE_SHAPE and densities.
t_us_gpu_dense = [1302.6769971475005, 1307.0351909846067, 1294.1299006342888, 1302.3867970332503, 1299.345400184393, 1298.9948969334364, 1306.315790861845, 1285.1458974182606, 1315.4757907614112, 1319.497195072472]
t_us_cpu_bcsr = [30.583798070438206, 2938.288199948147, 6361.140398075804, 9582.061402034014, 12796.884999261238, 15989.26380102057, 19310.838100500405, 22637.716299504973, 25892.800599103794, 28870.801199809648]
# Measured density: stored entries divided by total entries. This is not the
# Bernoulli probability that was requested.
density = [nse / N_DENSE_ENTRIES for nse in n_entries_sparse]
plt.figure(figsize=(10, 7))
plt.plot(density, t_us_dense, 'o-', label='Dense matvec CPU (XLA NEXT)')
plt.plot(density, t_us_gpu_dense, 'o-', label='Dense matvec GPU (XLA GPU)')
plt.plot(density, t_us_cpu_bcsr, 'o-', label='SpMV CPU (BCSR)')
plt.plot(density, t_us_sparse_bcoo, 'o-', label='SpMV GPU (cuSparse COO)')
plt.plot(density, t_us_sparse_bcsr, 'o-', label='SpMV GPU (cuSparse CSR)')
plt.title(f'{DENSE_SHAPE} SpMV time in microseconds vs density ({jax.devices()[0].platform}).')
plt.legend()
plt.xlabel('Fraction of non-zero entries')
plt.ylabel('microseconds')
plt.yscale('log');  # Trailing ';' suppresses the cell's output echo.
[ 0] Fraction of non-zero entries: 0.00000 Dense ARRAY matvec completed in 304,892us Sparse BCOO SpMV completed in 26us 0 nse Sparse BCSR SpMV completed in 440us 0 nse [ 1] Fraction of non-zero entries: 0.01000 Dense ARRAY matvec completed in 300,976us Sparse BCOO SpMV completed in 4,703us 2683428 nse Sparse BCSR SpMV completed in 9,363us 2683428 nse [ 2] Fraction of non-zero entries: 0.02000 Dense ARRAY matvec completed in 301,270us Sparse BCOO SpMV completed in 10,027us 5367424 nse Sparse BCSR SpMV completed in 10,465us 5367424 nse [ 3] Fraction of non-zero entries: 0.03000 Dense ARRAY matvec completed in 300,594us Sparse BCOO SpMV completed in 15,497us 8050085 nse Sparse BCSR SpMV completed in 15,433us 8050085 nse [ 4] Fraction of non-zero entries: 0.04000 Dense ARRAY matvec completed in 300,976us Sparse BCOO SpMV completed in 20,903us 10733923 nse Sparse BCSR SpMV completed in 20,344us 10733923 nse [ 5] Fraction of non-zero entries: 0.05000 Dense ARRAY matvec completed in 300,844us Sparse BCOO SpMV completed in 26,383us 13419841 nse Sparse BCSR SpMV completed in 24,971us 13419841 nse [ 6] Fraction of non-zero entries: 0.06000 Dense ARRAY matvec completed in 302,289us Sparse BCOO SpMV completed in 31,947us 16104215 nse Sparse BCSR SpMV completed in 29,716us 16104215 nse [ 7] Fraction of non-zero entries: 0.07000 Dense ARRAY matvec completed in 300,723us Sparse BCOO SpMV completed in 38,010us 18790488 nse Sparse BCSR SpMV completed in 35,196us 18790488 nse [ 8] Fraction of non-zero entries: 0.08000 Dense ARRAY matvec completed in 301,040us Sparse BCOO SpMV completed in 43,050us 21472986 nse Sparse BCSR SpMV completed in 40,224us 21472986 nse [ 9] Fraction of non-zero entries: 0.09000 Dense ARRAY matvec completed in 300,776us Sparse BCOO SpMV completed in 48,697us 24156594 nse Sparse BCSR SpMV completed in 45,140us 24156594 nse
Profile SpMM on GPU
import matplotlib.pyplot as plt
import numpy as np
import timeit
# timeit settings: number of repeats, and calls timed per repeat.
REPEATS, LOOPS = 1, 10
SECONDS_TO_MICROS = 10**6
# Square 1K x 1K operands. A smaller (128, 128) variant was also used.
DENSE_SHAPE = (1024,) * 2
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]
def matmul(a, b):
    """Return the matrix product of `a` and `b`, computed with `jnp.dot`.

    Left un-jitted so the notebook can wrap it with both `jax.jit` and
    `sparse_jit`.
    """
    result = jnp.dot(a, b)
    return result
matmul_jit = jax.jit(matmul)
# NOTE(review): `sparse_jit` is not defined in the visible part of this
# notebook; presumably an earlier cell provides a jit wrapper that accepts
# BCSR operands — confirm.
sparse_matmul_jit = sparse_jit(matmul)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
n_entries_sparse = []
t_us_dense = []
t_us_sparse_bcsr = []


def _time_us(op):
    """Return the per-call wall time of `op` in microseconds.

    Runs `op` LOOPS times per repeat, REPEATS times, and keeps the fastest
    repeat. Per the `timeit` docs, slower repeats reflect interference from
    other processes rather than the cost of `op`.
    """
    times = timeit.repeat(op, repeat=REPEATS, number=LOOPS)
    return SECONDS_TO_MICROS * min(times) / LOOPS


for i in range(10):
    fraction_non_zero = i / 100.0
    # Use separate keys for the sparsity mask and the values. Reusing
    # `subkey` for both would derive them from the same random bits, which
    # correlates which entries survive with the values they hold.
    mask_key, values_key = jax.random.split(subkey)
    m = (
        jax.random.bernoulli(key=mask_key, shape=DENSE_SHAPE, p=fraction_non_zero)
        * jax.random.normal(key=values_key, shape=DENSE_SHAPE)
    )
    mT = m.T
    sm_csr = jax.experimental.sparse.BCSR.fromdense(m)
    print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')
    # Record the entry count up front so `n_entries_sparse` stays aligned
    # with the timing lists even when a benchmark below fails.
    n_entries_sparse.append(sm_csr.nse)

    # Dense. A failed run records NaN, which keeps every timing list as long
    # as `n_entries_sparse` and leaves a gap in the log-scale plot. The
    # reference is reset so a failure cannot leave the previous density's
    # result behind for the sparse check below.
    dense_matmul_result = None
    try:
        op = lambda: matmul_jit(m, mT).block_until_ready()
        dense_matmul_result = op()  # Warmup
        t_us_dense.append(_time_us(op))
        print(f' Dense ARRAY matmul completed in {int(t_us_dense[-1]):8,}us')
    except Exception as e:  # Let KeyboardInterrupt/SystemExit propagate.
        print(f' Dense ARRAY matmul failed with error: {e}')
        t_us_dense.append(float('nan'))

    # Sparse BCSR
    sparse_matmul_result_csr = None
    try:
        op = lambda: sparse_matmul_jit(sm_csr, mT).block_until_ready()
        sparse_matmul_result_csr = op()  # Warmup
        # Verify that the dense and sparse paths compute the same result
        # (skipped when there is no dense reference for this density).
        if dense_matmul_result is not None and not jnp.allclose(
                dense_matmul_result, sparse_matmul_result_csr, rtol=1e-2):
            raise AssertionError('BCSR SpMM result differs from dense matmul')
        t_us_sparse_bcsr.append(_time_us(op))
        print(f' Sparse BCSR SpMM completed in {int(t_us_sparse_bcsr[-1]):8,}us {sm_csr.nse} nse')
    except Exception as e:
        print(f' Sparse BCSR SpMM failed with error: {e}')
        t_us_sparse_bcsr.append(float('nan'))
        print(dense_matmul_result)
        print(sparse_matmul_result_csr)
    print()
# 1K^2 CPU/GPU, 10 data points (one per density).
# Hard-coded timings (microseconds). `t_us_gpu_dense` matches the list
# printed in the sample output of the "Profile GEMM on CPU/GPU" cell.
# NOTE(review): `t_us_cpu_bcsr` presumably comes from a CPU-only BCSR run —
# confirm.
t_us_gpu_dense = [248.88459593057632, 262.53070682287216, 240.3387101367116, 291.3612173870206, 277.50979643315077, 264.9290021508932, 245.88459637016058, 243.75307839363813, 234.11749862134457, 240.35531096160412]
t_us_cpu_bcsr = [283.7551000993699, 6258.42199951876, 12172.323698177934, 18268.319498747587, 24266.25629886985, 30276.235000928864, 36314.31639951188, 42213.035697932355, 48481.74570070114, 54309.82770049013]
# Measured density: stored entries divided by total entries. This is not the
# Bernoulli probability that was requested.
density = [nse / N_DENSE_ENTRIES for nse in n_entries_sparse]
plt.figure(figsize=(10, 7))
plt.plot(density, t_us_dense, 'o-', label='Dense matmul CPU (XLA Next)')
plt.plot(density, t_us_gpu_dense, 'o-', label='Dense matmul GPU (XLA GPU)')
plt.plot(density, t_us_cpu_bcsr, 'o-', label='SpMM CPU (BCSR)')
plt.plot(density, t_us_sparse_bcsr, 'o-', label='SpMM GPU (cuSparse CSR)')
plt.title(f'{DENSE_SHAPE} SpMM time in microseconds vs density ({jax.devices()[0].platform}).')
plt.legend()
plt.xlabel('Fraction of non-zero entries')
plt.ylabel('microseconds')
plt.yscale('log');  # Trailing ';' suppresses the cell's output echo.
[ 0] Fraction of non-zero entries: 0.00000 Dense ARRAY matmul completed in 3,090,630us Sparse BCSR SpMM completed in 3,286us 0 nse [ 1] Fraction of non-zero entries: 0.01000 Dense ARRAY matmul completed in 2,911,031us Sparse BCSR SpMM completed in 3,339us 10461 nse [ 2] Fraction of non-zero entries: 0.02000 Dense ARRAY matmul completed in 2,987,564us Sparse BCSR SpMM completed in 3,507us 20727 nse [ 3] Fraction of non-zero entries: 0.03000 Dense ARRAY matmul completed in 2,959,827us Sparse BCSR SpMM completed in 3,682us 31319 nse [ 4] Fraction of non-zero entries: 0.04000 Dense ARRAY matmul completed in 2,987,060us Sparse BCSR SpMM completed in 4,288us 41781 nse [ 5] Fraction of non-zero entries: 0.05000 Dense ARRAY matmul completed in 3,000,965us Sparse BCSR SpMM completed in 3,902us 52246 nse [ 6] Fraction of non-zero entries: 0.06000 Dense ARRAY matmul completed in 3,001,546us Sparse BCSR SpMM completed in 4,162us 62740 nse [ 7] Fraction of non-zero entries: 0.07000 Dense ARRAY matmul completed in 2,950,979us Sparse BCSR SpMM completed in 4,341us 73046 nse [ 8] Fraction of non-zero entries: 0.08000 Dense ARRAY matmul completed in 2,883,313us Sparse BCSR SpMM completed in 4,383us 83590 nse [ 9] Fraction of non-zero entries: 0.09000 Dense ARRAY matmul completed in 2,876,997us Sparse BCSR SpMM completed in 4,621us 94122 nse
Profile MV on CPU/GPU
import numpy as np
import timeit
# timeit settings: number of repeats, and calls timed per repeat.
REPEATS, LOOPS = 1, 10
SECONDS_TO_MICROS = 10**6
# A square 16K x 16K matrix, multiplied by a vector with one entry per column.
DENSE_SHAPE = (16 * 1024,) * 2
VSHAPE = DENSE_SHAPE[1:]
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]
def matvec(m, v):
    """Return the matrix-vector product of `m` and `v`, computed with `jnp.dot`.

    Left un-jitted so the notebook can wrap it with `jax.jit` (and with
    `sparse_jit`).
    """
    product = jnp.dot(m, v)
    return product
matvec_jit = jax.jit(matvec)
# NOTE(review): unused in this dense-only cell, and `sparse_jit` is not
# defined in the visible part of the notebook. Kept so this cell defines the
# same globals as the SpMV cell — confirm whether it can be dropped.
sparse_matvec_jit = sparse_jit(matvec)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
n_entries_sparse = []
t_us_dense = []
t_us_sparse_bcoo = []
t_us_sparse_bcsr = []


def _time_us(op):
    """Return the per-call wall time of `op` in microseconds.

    Runs `op` LOOPS times per repeat, REPEATS times, and keeps the fastest
    repeat. Per the `timeit` docs, slower repeats reflect interference from
    other processes rather than the cost of `op`.
    """
    times = timeit.repeat(op, repeat=REPEATS, number=LOOPS)
    return SECONDS_TO_MICROS * min(times) / LOOPS


v = jnp.ones(VSHAPE)  # Loop-invariant right-hand side.
for i in range(10):
    fraction_non_zero = i / 100.0
    # Use separate keys for the sparsity mask and the values. Reusing
    # `subkey` for both would derive them from the same random bits.
    mask_key, values_key = jax.random.split(subkey)
    m = (
        jax.random.bernoulli(key=mask_key, shape=DENSE_SHAPE, p=fraction_non_zero)
        * jax.random.normal(key=values_key, shape=DENSE_SHAPE)
    )
    print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')
    # Dense. A failed run records NaN so `t_us_dense` keeps one entry per
    # density.
    try:
        op = lambda: matvec_jit(m, v).block_until_ready()
        dense_matvec_result = op()  # Warmup
        t_us_dense.append(_time_us(op))
        print(f' Dense ARRAY matvec completed in {int(t_us_dense[-1]):8,}us')
    except Exception as e:  # Let KeyboardInterrupt/SystemExit propagate.
        print(f' Dense ARRAY matvec failed with error: {e}')
        t_us_dense.append(float('nan'))
# Printed so the timings can be pasted into the plotting cells as reference
# data.
print(DENSE_SHAPE)
print(t_us_dense)
[ 0] Fraction of non-zero entries: 0.00000 Dense ARRAY matvec completed in 1,311us [ 1] Fraction of non-zero entries: 0.01000 Dense ARRAY matvec completed in 1,298us [ 2] Fraction of non-zero entries: 0.02000 Dense ARRAY matvec completed in 1,330us [ 3] Fraction of non-zero entries: 0.03000 Dense ARRAY matvec completed in 1,321us [ 4] Fraction of non-zero entries: 0.04000 Dense ARRAY matvec completed in 1,317us [ 5] Fraction of non-zero entries: 0.05000 Dense ARRAY matvec completed in 1,331us [ 6] Fraction of non-zero entries: 0.06000 Dense ARRAY matvec completed in 1,354us [ 7] Fraction of non-zero entries: 0.07000 Dense ARRAY matvec completed in 1,303us [ 8] Fraction of non-zero entries: 0.08000 Dense ARRAY matvec completed in 1,308us [ 9] Fraction of non-zero entries: 0.09000 Dense ARRAY matvec completed in 1,292us (16384, 16384) [1311.2015090882778, 1298.0406172573566, 1330.2603038027883, 1321.4902952313423, 1317.8099179640412, 1331.6298834979534, 1354.8784190788865, 1303.788903169334, 1308.6162973195314, 1292.1323999762535]
Profile GEMM on CPU/GPU
import numpy as np
import timeit
# timeit settings: number of repeats, and calls timed per repeat.
REPEATS, LOOPS = 1, 10
SECONDS_TO_MICROS = 10**6
# Square 1K x 1K operands.
DENSE_SHAPE = (1024,) * 2
N_DENSE_ENTRIES = DENSE_SHAPE[0] * DENSE_SHAPE[1]
def matmul(a, b):
    """Return the matrix product of `a` and `b`, computed with `jnp.dot`.

    Left un-jitted so the notebook can wrap it with `jax.jit`.
    """
    result = jnp.dot(a, b)
    return result
matmul_jit = jax.jit(matmul)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
t_us_dense = []


def _time_us(op):
    """Return the per-call wall time of `op` in microseconds.

    Runs `op` LOOPS times per repeat, REPEATS times, and keeps the fastest
    repeat. Per the `timeit` docs, slower repeats reflect interference from
    other processes rather than the cost of `op`.
    """
    times = timeit.repeat(op, repeat=REPEATS, number=LOOPS)
    return SECONDS_TO_MICROS * min(times) / LOOPS


for i in range(10):
    fraction_non_zero = i / 100.0
    # Use separate keys for the sparsity mask and the values. Reusing
    # `subkey` for both would derive them from the same random bits.
    mask_key, values_key = jax.random.split(subkey)
    m = (
        jax.random.bernoulli(key=mask_key, shape=DENSE_SHAPE, p=fraction_non_zero)
        * jax.random.normal(key=values_key, shape=DENSE_SHAPE)
    )
    mT = m.T
    print(f'[{i:2}] Fraction of non-zero entries: {fraction_non_zero:.5f}')
    # Dense. A failed run records NaN so `t_us_dense` keeps one entry per
    # density.
    try:
        op = lambda: matmul_jit(m, mT).block_until_ready()
        dense_matmul_result = op()  # Warmup
        t_us_dense.append(_time_us(op))
        print(f' Dense ARRAY matmul completed in {int(t_us_dense[-1]):8,}us')
    except Exception as e:  # Let KeyboardInterrupt/SystemExit propagate.
        print(f' Dense ARRAY matmul failed with error: {e}')
        t_us_dense.append(float('nan'))
# Printed so the timings can be pasted into the plotting cells as reference
# data.
print(DENSE_SHAPE)
print(t_us_dense)
[ 0] Fraction of non-zero entries: 0.00000 Dense ARRAY matmul completed in 248us [ 1] Fraction of non-zero entries: 0.01000 Dense ARRAY matmul completed in 262us [ 2] Fraction of non-zero entries: 0.02000 Dense ARRAY matmul completed in 240us [ 3] Fraction of non-zero entries: 0.03000 Dense ARRAY matmul completed in 291us [ 4] Fraction of non-zero entries: 0.04000 Dense ARRAY matmul completed in 277us [ 5] Fraction of non-zero entries: 0.05000 Dense ARRAY matmul completed in 264us [ 6] Fraction of non-zero entries: 0.06000 Dense ARRAY matmul completed in 245us [ 7] Fraction of non-zero entries: 0.07000 Dense ARRAY matmul completed in 243us [ 8] Fraction of non-zero entries: 0.08000 Dense ARRAY matmul completed in 234us [ 9] Fraction of non-zero entries: 0.09000 Dense ARRAY matmul completed in 240us (1024, 1024) [248.88459593057632, 262.53070682287216, 240.3387101367116, 291.3612173870206, 277.50979643315077, 264.9290021508932, 245.88459637016058, 243.75307839363813, 234.11749862134457, 240.35531096160412]