Sandbox for porting Sparse Jax CPU Benchmark
See the Colab notebook; it is currently available only to Googlers.
This Colab can be executed on a DeepMind CPU runtime.
We need XLA-NEXT for fast sparse ops. The next cell shows how to set it up; it also allows importing sparse_jit either from P4HEAD (the default) or from a CitC workspace.
- soft restart: go to Runtime, Restart Runtime
- hard restart: go to Connect to a Borg Runtime, Stop Runtime, Start New
Setup
import functools
import jax
# This is required to get Sparse JAX (experimental for now).
from jax.experimental import sparse
import jax.numpy as jnp
import os
from colabtools import adhoc_import
# Workspace selectors, all left empty here. They are not read anywhere in the
# visible code; presumably they pick the CitC client / user / CL from which
# sparse_jit is imported -- TODO confirm against the import cell.
CITC_CLIENT = ""
CITC_USER = ""
CL_NUMBER = ""
# Use XLA-NEXT for fast sparse ops:
#   --xla_cpu_use_xla_runtime enables XLA-NEXT, which has fast sparse ops.
#   --xla_cpu_enable_mlir_tiling_and_fusion=false works around a few bugs
#     and should eventually be removed.
# NOTE(review): XLA_FLAGS is presumably only read when the XLA backend is
# initialised, so this must run before the first JAX computation -- confirm.
os.environ['XLA_FLAGS'] = '--xla_cpu_use_xla_runtime --xla_cpu_enable_mlir_tiling_and_fusion=false'
# Set up for benchmarking
import time
import numpy as np
import matplotlib.pyplot as plt
import timeit
# Passed to timeit.repeat as `repeat`: number of timing rounds per benchmark.
REPEATS = 1
# Passed to timeit.repeat as `number`: calls per round; run() divides the
# round time by this to get a per-call figure.
LOOPS = 10
# Multiplier converting seconds to microseconds.
SECONDS_TO_MICROS = 1_000 * 1_000
def run(func, backend, mode, format, inputs, checksum):
    """Compile ``func`` for one backend and return a timing in microseconds.

    Args:
        func: Function to benchmark; it is lowered and called as
            ``func(*inputs)``.
        backend: ``'JAX'`` compiles with ``jax.jit(sparse.sparsify(func))``;
            any other value compiles with ``sparse_jit``.
        mode: ``'compile'`` to report the lower-and-compile time, or
            ``'runtime'`` to report the per-call execution time.
        format: Forwarded to ``sparse_jit`` as ``output_format``; not used by
            the JAX backend.
        inputs: Sequence of positional arguments for ``func``.
        checksum: Expected sum of the result. Only checked for the JAX
            backend.

    Returns:
        The measured time in microseconds. In ``'runtime'`` mode this is the
        slowest of ``REPEATS`` rounds divided by ``LOOPS`` calls per round.

    Raises:
        ValueError: If ``mode`` is neither ``'compile'`` nor ``'runtime'``.
        AssertionError: If the JAX result does not sum to ``checksum``.
    """
    # Reject a bad mode before paying for compilation.
    if mode not in ('compile', 'runtime'):
        raise ValueError(f"mode must be 'compile' or 'runtime', got {mode!r}")

    # perf_counter is monotonic and high resolution; time.time() is wall-clock
    # and may jump while the interval is being measured.
    start_time = time.perf_counter()
    if backend == 'JAX':
        compiled = jax.jit(sparse.sparsify(func)).lower(*inputs).compile()
    else:
        compiled = sparse_jit(func, output_format=format).lower(*inputs).compile()
    compile_time = time.perf_counter() - start_time
    if mode == 'compile':
        return compile_time * SECONDS_TO_MICROS

    def exec_func():
        # JAX dispatches asynchronously, so wait for the result; otherwise only
        # the dispatch cost would be timed. Leaves without a
        # block_until_ready method are returned unchanged.
        return jax.block_until_ready(compiled(*inputs))

    # The first call also serves as the correctness check and as a warm-up.
    res = exec_func()
    if backend == 'JAX':
        total = (res if isinstance(res, sparse.BCOO) else np.asarray(res)).sum()
        assert jnp.isclose(total, checksum)

    t = timeit.repeat(
        exec_func,
        repeat=REPEATS,
        number=LOOPS,
    )
    # Report the slowest round, averaged over the calls in that round.
    return SECONDS_TO_MICROS * max(t) / LOOPS
# Operand layouts for the Nop benchmark; also the row labels of its bar chart.
formats = ['dense', 'sparse']
# Sparse storage formats exercised by the later benchmarks.
sparse_formats = ['bcoo', 'bcsr'] # fails for some tests
# Backend labels: run() uses jax.jit + sparse.sparsify for 'JAX' and
# sparse_jit for anything else.
backends = ['JAX', 'MLIR']
# Measurements run() can report.
modes = ['compile', 'runtime']
# Values passed to gen_sparse_input as fraction_non_zero.
nses = [0.1, 0.2]
def gen_sparse_input(shape, fraction_non_zero, subkey):
    """Return a dense random array in which most entries are exactly zero.

    Each entry is a standard-normal sample multiplied by an independent
    Bernoulli mask, so roughly ``fraction_non_zero`` of the entries are
    non-zero.

    Args:
        shape: Shape of the array to generate.
        fraction_non_zero: Probability that an entry is kept.
        subkey: JAX PRNG key; the same key always yields the same array.

    Returns:
        The masked array, as a dense JAX array of the requested shape.
    """
    # A PRNG key must be consumed only once. Feeding the same key to both
    # draws makes the mask and the values statistically dependent, so derive
    # one key per draw.
    mask_key, value_key = jax.random.split(subkey)
    mask = jax.random.bernoulli(key=mask_key, p=fraction_non_zero, shape=shape)
    values = jax.random.normal(key=value_key, shape=shape)
    return mask * values
def shapeStr(shape):
    """Render a shape as its dimensions joined by 'x', e.g. [2, 3] -> '2x3'."""
    dims = map(str, shape)
    return 'x'.join(dims)
# Nop
#
# Times the identity function on a 10x10 matrix holding a single non-zero
# entry, once as a dense array and once as a BCOO matrix, for every backend.
# One bar chart is drawn per mode.
import pandas as pd

func = lambda tensor: tensor

for mode in modes:
    # One row of timings per operand format, one column per backend.
    all_times = []
    for format in formats:
        if format != 'dense':
            # Sparse operand: the same single 1.0 at position (0, 0).
            data = jnp.array([1.0])
            indices = jnp.array([[0, 0]])
            mat = sparse.BCOO((data, indices), shape=(10, 10))
        else:
            mat = np.zeros((10, 10), dtype=float)
            mat[0, 0] = 1.0
        axis = [f'{backend}' for backend in backends]
        times = [
            run(func, backend, mode, format, [mat], 1.0)
            for backend in backends
        ]
        all_times.append(times)
    df = pd.DataFrame(all_times, index=formats, columns=axis)
    df.plot.bar(rot=0).set_title(f'{mode} time for JAX and MLIR sparse backend')
# Elementwise add
#
# Times mat1 + mat2, where mat2 is the transpose of mat1, on square matrices
# of side 2**7 .. 2**10 for every density in nses. One line chart per mode.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
func = lambda mat1, mat2: mat1 + mat2

# all_times[mode][(backend, format)] collects one timing per (shape, nse)
# configuration. The JAX backend is not run on bcsr.
all_times = {
    mode: {
        (backend, format): []
        for backend in backends
        for format in sparse_formats
        if not (backend == 'JAX' and format == 'bcsr')
    }
    for mode in modes
}

shapes = []  # x-axis labels, one per (shape, nse) configuration
for shape in ([2**i] * 2 for i in range(7, 11)):
    for nse in nses:
        shapes.append(f'{shapeStr(shape)}\nnse={nse}')
        mat1_d = gen_sparse_input(shape, nse, subkey)
        mat2_d = mat1_d.T
        # Reference value computed on the dense operands.
        checksum = func(mat1_d, mat2_d).sum()
        for format in sparse_formats:
            if format == 'bcoo':
                to_sparse = sparse.BCOO.fromdense
            else:
                assert format == 'bcsr'
                to_sparse = sparse.BCSR.fromdense
            mat1 = to_sparse(mat1_d)
            mat2 = to_sparse(mat2_d)
            for mode in modes:
                for backend in backends:
                    if backend == 'JAX' and format == 'bcsr':
                        continue
                    times = all_times[mode][(backend, format)]
                    times.append(
                        run(func, backend, mode, 'sparse', [mat1, mat2], checksum)
                    )
            del mat1, mat2
        del mat1_d, mat2_d

for mode in modes:
    fig = plt.figure()
    plt.subplots_adjust(bottom=0.10)
    plt.title(f'{mode} time for JAX and MLIR sparse backend')
    plt.xlabel('Parameters')
    plt.ylabel('Elementwise Add microseconds')
    if mode == 'runtime':
        plt.yscale('log')
    legends = []
    # The dict preserves insertion order: backend-major, then format.
    for (backend, format), times in all_times[mode].items():
        # Draw the time for the (backend, format).
        plt.plot(shapes, times, 'o-')
        legends.append(f'{backend}_{format}')
    plt.legend(legends)
    plt.show()
# Elementwise mul
#
# Times mat1 * mat2, where mat2 is the transpose of mat1, on square matrices
# of side 2**7 and 2**8 for every density in nses. One line chart per mode.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
func = lambda mat1, mat2: mat1 * mat2

# all_times[mode][(backend, format)] collects one timing per (shape, nse)
# configuration. The JAX backend is not run on bcsr.
all_times = {
    mode: {
        (backend, format): []
        for backend in backends
        for format in sparse_formats
        if not (backend == 'JAX' and format == 'bcsr')
    }
    for mode in modes
}

shapes = []  # x-axis labels, one per (shape, nse) configuration
for shape in ([2**i] * 2 for i in range(7, 9)):  # 11 crash the runtime
    for nse in nses:
        shapes.append(f'{shapeStr(shape)}\nnse={nse}')
        mat1_d = gen_sparse_input(shape, nse, subkey)
        mat2_d = mat1_d.T
        # Reference value computed on the dense operands.
        checksum = func(mat1_d, mat2_d).sum()
        for format in sparse_formats:
            if format == 'bcoo':
                to_sparse = sparse.BCOO.fromdense
            else:
                assert format == 'bcsr'
                to_sparse = sparse.BCSR.fromdense
            mat1 = to_sparse(mat1_d)
            mat2 = to_sparse(mat2_d)
            for mode in modes:
                for backend in backends:
                    if backend == 'JAX' and format == 'bcsr':
                        continue
                    times = all_times[mode][(backend, format)]
                    times.append(
                        run(func, backend, mode, 'sparse', [mat1, mat2], checksum)
                    )
            del mat1, mat2
        del mat1_d, mat2_d

for mode in modes:
    fig = plt.figure()
    plt.subplots_adjust(bottom=0.10)
    plt.title(f'{mode} time for JAX and MLIR sparse backend')
    plt.xlabel('Parameters')
    plt.ylabel('Elementwise Mul microseconds')
    if mode == 'runtime':
        plt.yscale('log')
    legends = []
    # The dict preserves insertion order: backend-major, then format.
    for (backend, format), times in all_times[mode].items():
        # Draw the time for the (backend, format).
        plt.plot(shapes, times, 'o-')
        legends.append(f'{backend}_{format}')
    plt.legend(legends)
    plt.show()
# Mat @ Vec
#
# Times mat @ vec on square matrices of side 2**7 .. 2**10 for every density
# in nses. The matrix is stored in each sparse format; the vector is always
# BCOO. One line chart per mode.
key = jax.random.PRNGKey(0)
func = lambda mat, vec: mat @ vec

# all_times[mode][(backend, format)] collects one timing per (shape, nse)
# configuration. The JAX backend is not run on bcsr.
all_times = {
    mode: {
        (backend, format): []
        for backend in backends
        for format in sparse_formats
        if not (backend == 'JAX' and format == 'bcsr')
    }
    for mode in modes
}

shapes = []  # x-axis labels, one per (shape, nse) configuration
for shape in [[2**i, 2**i] for i in range(7, 11)]:
    for nse in nses:
        shapes.append(f'{shapeStr(shape)}\nnse={nse}')
        # A PRNG key must be consumed only once: derive a separate key for
        # each operand so the matrix and the vector are independent draws.
        # The sequence stays reproducible because it starts from PRNGKey(0).
        key, mat_key, vec_key = jax.random.split(key, 3)
        mat_d = gen_sparse_input(shape, nse, mat_key)
        vec_d = gen_sparse_input([shape[1]], nse, vec_key)
        # Reference value computed on the dense operands.
        checksum = func(mat_d, vec_d).sum()
        # The vector does not depend on the matrix format, so convert it once.
        vec = sparse.BCOO.fromdense(vec_d)
        for format in sparse_formats:
            if format == 'bcoo':
                mat = sparse.BCOO.fromdense(mat_d)
            else:
                assert format == 'bcsr'
                mat = sparse.BCSR.fromdense(mat_d)
            for mode in modes:
                for backend in backends:
                    if backend == 'JAX' and format == 'bcsr':
                        continue
                    all_times[mode][(backend, format)].append(
                        run(func, backend, mode, 'sparse', [mat, vec], checksum)
                    )
            del mat
        del mat_d, vec_d, vec

for mode in modes:
    plt.figure()
    plt.subplots_adjust(bottom=0.10)
    plt.title(f'{mode} time for JAX and MLIR sparse backend')
    plt.xlabel('Parameters')
    plt.ylabel('Mat @ Vec Mul microseconds')
    if mode == 'runtime':
        plt.yscale('log')
    legends = []
    for backend in backends:
        for format in sparse_formats:
            if backend == 'JAX' and format == 'bcsr':
                continue
            # Draw the time for the (backend, format).
            plt.plot(shapes, all_times[mode][(backend, format)], 'o-')
            legends.append(f'{backend}_{format}')
    plt.legend(legends)
    plt.show()
# Mat @ mat
#
# Times mat1 @ mat2, where mat2 is the transpose of mat1, on a square matrix
# of side 2**7 for every density in nses. One line chart per mode.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
func = lambda mat1, mat2: mat1 @ mat2

# all_times[mode][(backend, format)] collects one timing per (shape, nse)
# configuration. The JAX backend is not run on bcsr.
all_times = {
    mode: {
        (backend, format): []
        for backend in backends
        for format in sparse_formats
        if not (backend == 'JAX' and format == 'bcsr')
    }
    for mode in modes
}

shapes = []  # x-axis labels, one per (shape, nse) configuration
for shape in ([2**i] * 2 for i in range(7, 8)):  # 11 crash the runtime
    for nse in nses:
        shapes.append(f'{shapeStr(shape)}\nnse={nse}')
        mat1_d = gen_sparse_input(shape, nse, subkey)
        mat2_d = mat1_d.T
        # Reference value computed on the dense operands.
        checksum = func(mat1_d, mat2_d).sum()
        for format in sparse_formats:
            if format == 'bcoo':
                to_sparse = sparse.BCOO.fromdense
            else:
                assert format == 'bcsr'
                to_sparse = sparse.BCSR.fromdense
            mat1 = to_sparse(mat1_d)
            mat2 = to_sparse(mat2_d)
            for mode in modes:
                for backend in backends:
                    if backend == 'JAX' and format == 'bcsr':
                        continue
                    times = all_times[mode][(backend, format)]
                    times.append(
                        run(func, backend, mode, 'sparse', [mat1, mat2], checksum)
                    )
            del mat1, mat2
        del mat1_d, mat2_d

for mode in modes:
    fig = plt.figure()
    plt.subplots_adjust(bottom=0.10)
    plt.title(f'{mode} time for JAX and MLIR sparse backend')
    plt.xlabel('Parameters')
    plt.ylabel('Mat @ Mat microseconds')
    if mode == 'runtime':
        plt.yscale('log')
    legends = []
    # The dict preserves insertion order: backend-major, then format.
    for (backend, format), times in all_times[mode].items():
        # Draw the time for the (backend, format).
        plt.plot(shapes, times, 'o-')
        legends.append(f'{backend}_{format}')
    plt.legend(legends)
    plt.show()
# SDDMM - sample dense-dense matmul
#
# Times sample * (mat1 @ mat2) for each shape triple in all_shapes and every
# density in nses. Only `sample` is converted to a sparse format; mat1 and
# mat2 are passed as the dense arrays returned by gen_sparse_input.
# One line chart per mode.
key = jax.random.PRNGKey(0)
func = lambda sample, mat1, mat2: sample * (mat1 @ mat2)
all_shapes = [
    ([8, 32], [8, 16], [16, 32]),
    ([1024, 1024], [1024, 1024], [1024, 1024]),
]

# all_times[mode][(backend, format)] collects one timing per (shapes, nse)
# configuration. The JAX backend is not run on bcsr.
all_times = {
    mode: {
        (backend, format): []
        for backend in backends
        for format in sparse_formats
        if not (backend == 'JAX' and format == 'bcsr')
    }
    for mode in modes
}

shapes = []  # x-axis labels, one per (shapes, nse) configuration
for s_shape, l_shape, r_shape in all_shapes:
    for nse in nses:
        shapes.append(f'{str(s_shape)}\n{str(l_shape)}\n{str(r_shape)}\nnse={nse}')
        # Each operand gets its own key. With one shared key, operands of
        # equal shape (the 1024x1024 case) would be the very same matrix.
        # The sequence stays reproducible because it starts from PRNGKey(0).
        key, sample_key, lhs_key, rhs_key = jax.random.split(key, 4)
        sample_d = gen_sparse_input(s_shape, nse, sample_key)
        mat1 = gen_sparse_input(l_shape, nse, lhs_key)
        mat2 = gen_sparse_input(r_shape, nse, rhs_key)
        # Reference value computed on the dense operands.
        checksum = func(sample_d, mat1, mat2).sum()
        for format in sparse_formats:
            if format == 'bcoo':
                sample = sparse.BCOO.fromdense(sample_d)
            else:
                assert format == 'bcsr'
                sample = sparse.BCSR.fromdense(sample_d)
            for mode in modes:
                for backend in backends:
                    if backend == 'JAX' and format == 'bcsr':
                        continue
                    all_times[mode][(backend, format)].append(
                        run(func, backend, mode, 'sparse', [sample, mat1, mat2], checksum)
                    )
            del sample
        del sample_d, mat1, mat2

for mode in modes:
    plt.figure()
    plt.subplots_adjust(bottom=0.10)
    plt.title(f'{mode} time for JAX and MLIR sparse backend')
    plt.xlabel('Parameters')
    plt.ylabel('SDDMM microseconds')
    if mode == 'runtime':
        plt.yscale('log')
    legends = []
    for backend in backends:
        for format in sparse_formats:
            if backend == 'JAX' and format == 'bcsr':
                continue
            # Draw the time for the (backend, format).
            plt.plot(shapes, all_times[mode][(backend, format)], 'o-')
            legends.append(f'{backend}_{format}')
    plt.legend(legends)
    plt.show()
# Einsum
#
# Times jnp.einsum('ijk,jil->kl', mat1, mat2) on BCOO operands for each shape
# pair in all_shapes and every density in nses. Only the bcoo format is
# benchmarked here. One line chart per mode.
key = jax.random.PRNGKey(0)
func = lambda mat1, mat2: jnp.einsum('ijk,jil->kl', mat1, mat2)
all_shapes = [([8, 1, 16], [1, 8, 4]), ([32, 2, 128], [2, 32, 64])]
format = 'bcoo'

# all_times[mode][(backend, format)] collects one timing per (shapes, nse)
# configuration.
all_times = {
    mode: {(backend, format): [] for backend in backends}
    for mode in modes
}

shapes = []  # x-axis labels, one per (shapes, nse) configuration
for l_shape, r_shape in all_shapes:
    for nse in nses:
        shapes.append(f'{shapeStr(l_shape)}\n{shapeStr(r_shape)}\n(nse={nse})')
        # A PRNG key must be consumed only once: derive a separate key for
        # each operand. The sequence stays reproducible because it starts
        # from PRNGKey(0).
        key, lhs_key, rhs_key = jax.random.split(key, 3)
        mat1_d = gen_sparse_input(l_shape, nse, lhs_key)
        mat2_d = gen_sparse_input(r_shape, nse, rhs_key)
        # Reference value computed on the dense operands.
        checksum = func(mat1_d, mat2_d).sum()
        mat1 = sparse.BCOO.fromdense(mat1_d)
        mat2 = sparse.BCOO.fromdense(mat2_d)
        # No (JAX, bcsr) skip is needed: format is always 'bcoo' here.
        for mode in modes:
            for backend in backends:
                all_times[mode][(backend, format)].append(
                    run(func, backend, mode, 'sparse', [mat1, mat2], checksum)
                )
        del mat1, mat2
        del mat1_d, mat2_d

for mode in modes:
    plt.figure()
    plt.subplots_adjust(bottom=0.10)
    plt.title(f'{mode} time for JAX and MLIR sparse backend')
    plt.xlabel('Parameters')
    plt.ylabel('Einsum(a, b, ijk,jil->kl) microseconds')
    if mode == 'runtime':
        plt.yscale('log')
    legends = []
    for backend in backends:
        # Draw the time for the (backend, format).
        plt.plot(shapes, all_times[mode][(backend, format)], 'o-')
        legends.append(f'{backend}_{format}')
    plt.legend(legends)
    plt.show()