Sandbox for porting Sparse Jax CPU Benchmark

See the Colab notebook. It is currently available only to Googlers.


This Colab can be run on the DeepMind CPU runtime.

We need XLA-NEXT for fast sparse ops. The next cell shows how to set it up; it lets you import sparse_jit either from P4HEAD (the default) or from a CitC workspace.

  • soft restart: go to Runtime, Restart Runtime
  • hard restart: go to Connect to a Borg Runtime, Stop Runtime, Start New

Setup

# Set up for benchmarking
import time
import numpy as np
import matplotlib.pyplot as plt
import timeit

# timeit settings used by `run`: number of repetitions, and calls per
# repetition.
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000 * 1_000


def run(func, backend, mode, format, inputs, checksum):
  """Compiles `func` for `backend` and returns a timing in microseconds.

  Args:
    func: Callable to benchmark; it is invoked as `func(*inputs)`.
    backend: 'JAX' compiles with `jax.jit(sparse.sparsify(func))`; any other
      value compiles with `sparse_jit`.
    mode: 'compile' returns the wall time of lowering and compiling;
      'runtime' returns the time of one call to the compiled executable.
    format: Forwarded to `sparse_jit` as `output_format`; not used by the
      'JAX' backend.
    inputs: Sequence of positional arguments for `func`.
    checksum: Expected sum of the result; only checked for the 'JAX' backend.

  Returns:
    The measured time in microseconds.

  Raises:
    ValueError: If `mode` is neither 'compile' nor 'runtime'.
    AssertionError: If the 'JAX' result does not sum to `checksum`.
  """
  # Reject a bad mode before paying for a compilation.
  if mode not in ('compile', 'runtime'):
    raise ValueError(f"mode must be 'compile' or 'runtime', got {mode!r}")

  # perf_counter is monotonic and has the highest available resolution.
  start_time = time.perf_counter()
  if backend == 'JAX':
    compiled = jax.jit(sparse.sparsify(func)).lower(*inputs).compile()
  else:
    compiled = sparse_jit(func, output_format=format).lower(*inputs).compile()
  compile_time = time.perf_counter() - start_time
  if mode == 'compile':
    return compile_time * SECONDS_TO_MICROS

  def exec_func():
    # JAX dispatches asynchronously: wait for the result so that the timer
    # covers the computation and not just the dispatch.
    return jax.block_until_ready(compiled(*inputs))

  res = exec_func()
  if backend == 'JAX':
    total = (res if isinstance(res, sparse.BCOO) else np.asarray(res)).sum()
    assert jnp.isclose(total, checksum), (
        f'checksum mismatch: got {total}, expected {checksum}')

  t = timeit.repeat(exec_func, repeat=REPEATS, number=LOOPS)
  # The fastest repetition is the least disturbed by other system activity.
  return SECONDS_TO_MICROS * min(t) / LOOPS

# Input kinds compared by the Nop benchmark; also passed to `run` as `format`.
formats = ['dense', 'sparse']
# Sparse layouts iterated over by the remaining benchmarks.
sparse_formats = ['bcoo', 'bcsr'] # fails for some tests
# 'JAX' makes `run` use jax.jit + sparse.sparsify; any other value uses
# sparse_jit.
backends = ['JAX', 'MLIR']
# Measurements `run` can report: compilation time or per-call runtime.
modes = ['compile', 'runtime']
# Values passed to `gen_sparse_input` as `fraction_non_zero`.
nses = [0.1, 0.2]

def gen_sparse_input(shape, fraction_non_zero, subkey):
  """Returns a dense array of normal samples with most entries zeroed out.

  Args:
    shape: Shape of the generated array.
    fraction_non_zero: Probability that an entry is kept (non-zero).
    subkey: JAX PRNG key; it is consumed by this call and must not be reused.

  Returns:
    An array of the given shape: standard-normal samples multiplied by a
    Bernoulli(`fraction_non_zero`) mask.
  """
  # Use independent keys for the mask and the values. Feeding one key to both
  # samplers makes them share random bits, which ties the values that survive
  # the mask to the mask itself.
  mask_key, value_key = jax.random.split(subkey)
  mask = jax.random.bernoulli(key=mask_key, p=fraction_non_zero, shape=shape)
  values = jax.random.normal(key=value_key, shape=shape)
  return mask * values

def shapeStr(shape):
  """Formats a shape as its dimensions joined by 'x', e.g. [2, 3] -> '2x3'."""
  dims = map(str, shape)
  return 'x'.join(dims)
# Nop
#
# Times the identity function on a 10x10 matrix holding a single non-zero
# entry, once as a dense array and once as a BCOO matrix, for every backend,
# and draws one bar chart per mode.
import pandas as pd
func = lambda tensor: tensor
for mode in modes:
  all_times = []
  for fmt in formats:
    if fmt != 'dense':
      mat = sparse.BCOO(
          (jnp.array([1.0]), jnp.array([[0, 0]])), shape=(10, 10))
    else:
      mat = np.zeros((10, 10), dtype=float)
      mat[0, 0] = 1.0

    # One row per format, one column per backend.
    all_times.append(
        [run(func, backend, mode, fmt, [mat], 1.0) for backend in backends])

  axis = [f'{backend}' for backend in backends]
  df = pd.DataFrame(all_times, index=formats, columns=axis)
  chart = df.plot.bar(rot=0)
  chart.set_title(f'{mode} time for JAX and MLIR sparse backend')

png

png

# Elementwise add
#
# Benchmarks `mat1 + mat2`, where mat2 is the transpose of mat1, for every
# benchmarked (backend, sparse format) pair and plots the timings per mode.
key = jax.random.PRNGKey(0)
func = lambda mat1, mat2: mat1 + mat2

# (backend, format) pairs to time; the (JAX, bcsr) pair is excluded.
configs = [
    (backend, fmt)
    for backend in backends
    for fmt in sparse_formats
    if not (backend == 'JAX' and fmt == 'bcsr')
]
all_times = {mode: {config: [] for config in configs} for mode in modes}

shapes = []
for shape in [[2**i, 2**i] for i in range(7, 11)]:
  for nse in nses:
    shapes.append(f'{shapeStr(shape)}\nnse={nse}')

    # Split off a fresh subkey for every input: JAX PRNG keys must not be
    # reused, otherwise every matrix is derived from the same random bits.
    key, subkey = jax.random.split(key)
    mat1_d = gen_sparse_input(shape, nse, subkey)
    mat2_d = mat1_d.T
    checksum = func(mat1_d, mat2_d).sum()
    for fmt in sparse_formats:
      if fmt == 'bcoo':
        mat1 = jax.experimental.sparse.BCOO.fromdense(mat1_d)
        mat2 = jax.experimental.sparse.BCOO.fromdense(mat2_d)
      else:
        assert fmt == 'bcsr'
        mat1 = jax.experimental.sparse.BCSR.fromdense(mat1_d)
        mat2 = jax.experimental.sparse.BCSR.fromdense(mat2_d)

      for mode in modes:
        for backend in backends:
          if (backend, fmt) not in all_times[mode]:
            continue
          all_times[mode][(backend, fmt)].append(
              run(func, backend, mode, 'sparse', [mat1, mat2], checksum))
      # Release the sparse operands before building the next format.
      del mat1, mat2
    del mat1_d, mat2_d

for mode in modes:
  plt.figure()
  plt.subplots_adjust(bottom=0.10)
  plt.title(f'{mode} time for JAX and MLIR sparse backend')
  plt.xlabel('Parameters')
  plt.ylabel('Elementwise Add microseconds')
  if mode == 'runtime':
    plt.yscale('log')
  legends = []
  # Dict insertion order matches the backend/format order of `configs`.
  for (backend, fmt), times in all_times[mode].items():
    plt.plot(shapes, times, 'o-')
    legends.append(f'{backend}_{fmt}')
  plt.legend(legends)
  plt.show()

png

png

# Elementwise mul
#
# Benchmarks `mat1 * mat2`, where mat2 is the transpose of mat1, for every
# benchmarked (backend, sparse format) pair and plots the timings per mode.
key = jax.random.PRNGKey(0)
func = lambda mat1, mat2: mat1 * mat2

# (backend, format) pairs to time; the (JAX, bcsr) pair is excluded.
configs = [
    (backend, fmt)
    for backend in backends
    for fmt in sparse_formats
    if not (backend == 'JAX' and fmt == 'bcsr')
]
all_times = {mode: {config: [] for config in configs} for mode in modes}

shapes = []
# Using 11 as the upper bound crashes the runtime.
for shape in [[2**i, 2**i] for i in range(7, 9)]:
  for nse in nses:
    shapes.append(f'{shapeStr(shape)}\nnse={nse}')

    # Split off a fresh subkey for every input: JAX PRNG keys must not be
    # reused, otherwise every matrix is derived from the same random bits.
    key, subkey = jax.random.split(key)
    mat1_d = gen_sparse_input(shape, nse, subkey)
    mat2_d = mat1_d.T
    checksum = func(mat1_d, mat2_d).sum()
    for fmt in sparse_formats:
      if fmt == 'bcoo':
        mat1 = jax.experimental.sparse.BCOO.fromdense(mat1_d)
        mat2 = jax.experimental.sparse.BCOO.fromdense(mat2_d)
      else:
        assert fmt == 'bcsr'
        mat1 = jax.experimental.sparse.BCSR.fromdense(mat1_d)
        mat2 = jax.experimental.sparse.BCSR.fromdense(mat2_d)

      for mode in modes:
        for backend in backends:
          if (backend, fmt) not in all_times[mode]:
            continue
          all_times[mode][(backend, fmt)].append(
              run(func, backend, mode, 'sparse', [mat1, mat2], checksum))
      # Release the sparse operands before building the next format.
      del mat1, mat2
    del mat1_d, mat2_d

for mode in modes:
  plt.figure()
  plt.subplots_adjust(bottom=0.10)
  plt.title(f'{mode} time for JAX and MLIR sparse backend')
  plt.xlabel('Parameters')
  plt.ylabel('Elementwise Mul microseconds')
  if mode == 'runtime':
    plt.yscale('log')
  legends = []
  # Dict insertion order matches the backend/format order of `configs`.
  for (backend, fmt), times in all_times[mode].items():
    plt.plot(shapes, times, 'o-')
    legends.append(f'{backend}_{fmt}')
  plt.legend(legends)
  plt.show()

png

png

# Mat @ Vec
#
# Benchmarks `mat @ vec` with a sparse matrix (BCOO or BCSR) and a BCOO vector
# for every benchmarked (backend, sparse format) pair and plots the timings.
key = jax.random.PRNGKey(0)
func = lambda mat, vec: mat @ vec

# (backend, format) pairs to time; the (JAX, bcsr) pair is excluded.
configs = [
    (backend, fmt)
    for backend in backends
    for fmt in sparse_formats
    if not (backend == 'JAX' and fmt == 'bcsr')
]
all_times = {mode: {config: [] for config in configs} for mode in modes}

shapes = []
for shape in [[2**i, 2**i] for i in range(7, 11)]:
  for nse in nses:
    shapes.append(f'{shapeStr(shape)}\nnse={nse}')

    # Give the matrix and the vector their own subkeys: JAX PRNG keys must not
    # be reused, otherwise both operands come from the same random bits.
    key, mat_key, vec_key = jax.random.split(key, 3)
    mat_d = gen_sparse_input(shape, nse, mat_key)
    vec_d = gen_sparse_input([shape[1]], nse, vec_key)
    checksum = func(mat_d, vec_d).sum()
    for fmt in sparse_formats:
      if fmt == 'bcoo':
        mat = jax.experimental.sparse.BCOO.fromdense(mat_d)
      else:
        assert fmt == 'bcsr'
        mat = jax.experimental.sparse.BCSR.fromdense(mat_d)
      # The vector is stored as BCOO for both matrix formats.
      vec = jax.experimental.sparse.BCOO.fromdense(vec_d)

      for mode in modes:
        for backend in backends:
          if (backend, fmt) not in all_times[mode]:
            continue
          all_times[mode][(backend, fmt)].append(
              run(func, backend, mode, 'sparse', [mat, vec], checksum))
      # Release the sparse operands before building the next format.
      del mat, vec
    del mat_d, vec_d

for mode in modes:
  plt.figure()
  plt.subplots_adjust(bottom=0.10)
  plt.title(f'{mode} time for JAX and MLIR sparse backend')
  plt.xlabel('Parameters')
  plt.ylabel('Mat @ Vec Mul microseconds')
  if mode == 'runtime':
    plt.yscale('log')
  legends = []
  # Dict insertion order matches the backend/format order of `configs`.
  for (backend, fmt), times in all_times[mode].items():
    plt.plot(shapes, times, 'o-')
    legends.append(f'{backend}_{fmt}')
  plt.legend(legends)
  plt.show()

png

png

# Mat @ mat
#
# Benchmarks `mat1 @ mat2`, where mat2 is the transpose of mat1, for every
# benchmarked (backend, sparse format) pair and plots the timings per mode.

key = jax.random.PRNGKey(0)
func = lambda mat1, mat2: mat1 @ mat2

# (backend, format) pairs to time; the (JAX, bcsr) pair is excluded.
configs = [
    (backend, fmt)
    for backend in backends
    for fmt in sparse_formats
    if not (backend == 'JAX' and fmt == 'bcsr')
]
all_times = {mode: {config: [] for config in configs} for mode in modes}

shapes = []
# Using 11 as the upper bound crashes the runtime.
for shape in [[2**i, 2**i] for i in range(7, 8)]:
  for nse in nses:
    shapes.append(f'{shapeStr(shape)}\nnse={nse}')

    # Split off a fresh subkey for every input: JAX PRNG keys must not be
    # reused, otherwise every matrix is derived from the same random bits.
    key, subkey = jax.random.split(key)
    mat1_d = gen_sparse_input(shape, nse, subkey)
    mat2_d = mat1_d.T
    checksum = func(mat1_d, mat2_d).sum()
    for fmt in sparse_formats:
      if fmt == 'bcoo':
        mat1 = jax.experimental.sparse.BCOO.fromdense(mat1_d)
        mat2 = jax.experimental.sparse.BCOO.fromdense(mat2_d)
      else:
        assert fmt == 'bcsr'
        mat1 = jax.experimental.sparse.BCSR.fromdense(mat1_d)
        mat2 = jax.experimental.sparse.BCSR.fromdense(mat2_d)

      for mode in modes:
        for backend in backends:
          if (backend, fmt) not in all_times[mode]:
            continue
          all_times[mode][(backend, fmt)].append(
              run(func, backend, mode, 'sparse', [mat1, mat2], checksum))
      # Release the sparse operands before building the next format.
      del mat1, mat2
    del mat1_d, mat2_d

for mode in modes:
  plt.figure()
  plt.subplots_adjust(bottom=0.10)
  plt.title(f'{mode} time for JAX and MLIR sparse backend')
  plt.xlabel('Parameters')
  plt.ylabel('Mat @ Mat microseconds')
  if mode == 'runtime':
    plt.yscale('log')
  legends = []
  # Dict insertion order matches the backend/format order of `configs`.
  for (backend, fmt), times in all_times[mode].items():
    plt.plot(shapes, times, 'o-')
    legends.append(f'{backend}_{fmt}')
  plt.legend(legends)
  plt.show()

png

png

# SDDMM - sampled dense-dense matmul
#
# Benchmarks `sample * (mat1 @ mat2)`, where only `sample` is converted to a
# sparse format and mat1/mat2 are passed as dense arrays, for every
# benchmarked (backend, sparse format) pair.
key = jax.random.PRNGKey(0)
func = lambda sample, mat1, mat2: sample * (mat1 @ mat2)
all_shapes = [([8, 32], [8, 16], [16, 32]), ([1024, 1024], [1024, 1024],[1024, 1024])]

# (backend, format) pairs to time; the (JAX, bcsr) pair is excluded.
configs = [
    (backend, fmt)
    for backend in backends
    for fmt in sparse_formats
    if not (backend == 'JAX' and fmt == 'bcsr')
]
all_times = {mode: {config: [] for config in configs} for mode in modes}

shapes = []
for s_shape, l_shape, r_shape in all_shapes:
  for nse in nses:
    shapes.append(f'{s_shape}\n{l_shape}\n{r_shape}\nnse={nse}')

    # Give each operand its own subkey: JAX PRNG keys must not be reused,
    # otherwise all three arrays come from the same random bits.
    key, sample_key, lhs_key, rhs_key = jax.random.split(key, 4)
    sample_d = gen_sparse_input(s_shape, nse, sample_key)
    mat1 = gen_sparse_input(l_shape, nse, lhs_key)
    mat2 = gen_sparse_input(r_shape, nse, rhs_key)
    checksum = func(sample_d, mat1, mat2).sum()

    for fmt in sparse_formats:
      if fmt == 'bcoo':
        sample = jax.experimental.sparse.BCOO.fromdense(sample_d)
      else:
        assert fmt == 'bcsr'
        sample = jax.experimental.sparse.BCSR.fromdense(sample_d)

      for mode in modes:
        for backend in backends:
          if (backend, fmt) not in all_times[mode]:
            continue
          all_times[mode][(backend, fmt)].append(
              run(func, backend, mode, 'sparse', [sample, mat1, mat2], checksum))
      # Release the sparse sample before building the next format.
      del sample
    del sample_d, mat1, mat2

for mode in modes:
  plt.figure()
  plt.subplots_adjust(bottom=0.10)
  plt.title(f'{mode} time for JAX and MLIR sparse backend')
  plt.xlabel('Parameters')
  plt.ylabel('SDDMM microseconds')
  if mode == 'runtime':
    plt.yscale('log')
  legends = []
  # Dict insertion order matches the backend/format order of `configs`.
  for (backend, fmt), times in all_times[mode].items():
    plt.plot(shapes, times, 'o-')
    legends.append(f'{backend}_{fmt}')
  plt.legend(legends)
  plt.show()

png

png

# Einsum
#
# Benchmarks `jnp.einsum('ijk,jil->kl', mat1, mat2)` on BCOO operands for
# every backend and plots the timings per mode.
key = jax.random.PRNGKey(0)
func = lambda mat1, mat2: jnp.einsum('ijk,jil->kl', mat1, mat2)
all_shapes = [([8, 1, 16], [1, 8, 4]), ([32, 2, 128], [2, 32, 64])]

# Only BCOO is benchmarked here, so no (backend, format) pair is skipped.
fmt = 'bcoo'
all_times = {
    mode: {(backend, fmt): [] for backend in backends} for mode in modes
}

shapes = []
for l_shape, r_shape in all_shapes:
  for nse in nses:
    shapes.append(f'{shapeStr(l_shape)}\n{shapeStr(r_shape)}\n(nse={nse})')

    # Give each operand its own subkey: JAX PRNG keys must not be reused,
    # otherwise both operands come from the same random bits.
    key, lhs_key, rhs_key = jax.random.split(key, 3)
    mat1_d = gen_sparse_input(l_shape, nse, lhs_key)
    mat2_d = gen_sparse_input(r_shape, nse, rhs_key)
    checksum = func(mat1_d, mat2_d).sum()

    mat1 = jax.experimental.sparse.BCOO.fromdense(mat1_d)
    mat2 = jax.experimental.sparse.BCOO.fromdense(mat2_d)

    for mode in modes:
      for backend in backends:
        all_times[mode][(backend, fmt)].append(
            run(func, backend, mode, 'sparse', [mat1, mat2], checksum))
    # Release the operands before generating the next configuration.
    del mat1, mat2
    del mat1_d, mat2_d

for mode in modes:
  plt.figure()
  plt.subplots_adjust(bottom=0.10)
  plt.title(f'{mode} time for JAX and MLIR sparse backend')
  plt.xlabel('Parameters')
  plt.ylabel('Einsum(a, b, ijk,jil->kl) microseconds')
  if mode == 'runtime':
    plt.yscale('log')
  legends = []
  for backend in backends:
    # Draw the time for the (backend, format).
    plt.plot(shapes, all_times[mode][(backend, fmt)], 'o-')
    legends.append(f'{backend}_{fmt}')
  plt.legend(legends)
  plt.show()

png

png