Sandbox for playing with Sparse JAX and the sparse compiler on CPU.

See the colab; currently it is only available to Googlers.


This colab can be executed on a Brain/DeepMind CPU runtime.

We need to use XLA-NEXT:CPU for fast sparse ops. The next cell shows how to set it up; it allows importing sparse_jit either from P4 HEAD (the default) or from a CitC workspace (the latter only picks up Python-related changes, though).

You must restart the runtime before re-running the cell below to pick up changes:

  • soft restart: Runtime → Restart Runtime
  • hard restart: Connect to a Borg Runtime → Stop Runtime → Start New

Setup

Importing modules from P4 HEAD.

Hello, sparse world!

Hello sparse world
BCOO(float32[1000], nse=6) <class 'jax.experimental.sparse.bcoo.BCOO'>
Computed 91.0
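The hello-world cell can be sketched roughly as below. The index positions and values are illustrative assumptions, chosen so that there are six stored entries and they sum to 91.0, matching the output above:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Dense length-1000 vector with six non-zero entries
# (positions and values are illustrative assumptions).
x = jnp.zeros(1000, dtype=jnp.float32)
x = x.at[jnp.array([2, 10, 42, 100, 500, 999])].set(
    jnp.array([1.0, 2.0, 4.0, 8.0, 16.0, 60.0]))

m = sparse.BCOO.fromdense(x)
print(m, type(m))  # BCOO(float32[1000], nse=6) <class '...BCOO'>

# sparsify lets ordinary jnp code consume sparse operands.
total = sparse.sparsify(jnp.sum)(m)
print("Computed", total)  # Computed 91.0
```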

Profile matmul vs sparsity

[ 0] Fraction of non-zero entries: 0.00000
  Dense ARRAY matmul completed in      601us
  Sparse BCOO matmul completed in       17us
  Sparse BCSR matmul completed in       16us
[ 1] Fraction of non-zero entries: 0.10000
  Dense ARRAY matmul completed in      590us
  Sparse BCOO matmul completed in      282us
  Sparse BCSR matmul completed in       72us
[ 2] Fraction of non-zero entries: 0.20000
  Dense ARRAY matmul completed in      594us
  Sparse BCOO matmul completed in      548us
  Sparse BCSR matmul completed in      148us
[ 3] Fraction of non-zero entries: 0.30000
  Dense ARRAY matmul completed in      597us
  Sparse BCOO matmul completed in      717us
  Sparse BCSR matmul completed in      215us
[ 4] Fraction of non-zero entries: 0.40000
  Dense ARRAY matmul completed in      590us
  Sparse BCOO matmul completed in      818us
  Sparse BCSR matmul completed in      295us
[ 5] Fraction of non-zero entries: 0.50000
  Dense ARRAY matmul completed in      590us
  Sparse BCOO matmul completed in      898us
  Sparse BCSR matmul completed in      374us
[ 6] Fraction of non-zero entries: 0.60000
  Dense ARRAY matmul completed in      596us
  Sparse BCOO matmul completed in      965us
  Sparse BCSR matmul completed in      457us
[ 7] Fraction of non-zero entries: 0.70000
  Dense ARRAY matmul completed in      591us
  Sparse BCOO matmul completed in    1,001us
  Sparse BCSR matmul completed in      530us
[ 8] Fraction of non-zero entries: 0.80000
  Dense ARRAY matmul completed in      596us
  Sparse BCOO matmul completed in    1,009us
  Sparse BCSR matmul completed in      598us
[ 9] Fraction of non-zero entries: 0.90000
  Dense ARRAY matmul completed in      595us
  Sparse BCOO matmul completed in      972us
  Sparse BCSR matmul completed in      658us

[plot: matmul runtime vs. fraction of non-zero entries]
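The benchmark can be reproduced with a loop along these lines. The matrix size, repetition count, and sparsity grid are assumptions, so absolute timings will differ from the table:

```python
import time
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def bench(f, *args, reps=10):
    f(*args).block_until_ready()          # warm up / compile once
    t0 = time.perf_counter()
    for _ in range(reps):
        f(*args).block_until_ready()
    return (time.perf_counter() - t0) / reps

n = 512                                   # matrix size (assumption)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(k1, (n, n), dtype=jnp.float32)

matmul = jax.jit(lambda m, v: m @ v)
sp_matmul = jax.jit(sparse.sparsify(lambda m, v: m @ v))

for frac in (0.1, 0.5, 0.9):              # fraction of non-zero entries
    vals = jax.random.uniform(k2, (n, n), dtype=jnp.float32)
    a = jnp.where(jax.random.uniform(k1, (n, n)) < frac, vals, 0.0)
    a_bcoo = sparse.BCOO.fromdense(a)
    a_bcsr = sparse.BCSR.fromdense(a)
    print(f"frac={frac:.1f}: "
          f"dense {bench(matmul, a, x) * 1e6:7.0f}us  "
          f"BCOO {bench(sp_matmul, a_bcoo, x) * 1e6:7.0f}us  "
          f"BCSR {bench(sp_matmul, a_bcsr, x) * 1e6:7.0f}us")
```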

Replicate the matmul results from the previous cell using an nn.Dense model

[ 0] Fraction of non-zero entries: 0.00001
  Dense  model completed in      601us
  Sparse model completed in       17us
[ 1] Fraction of non-zero entries: 0.00003
  Dense  model completed in      601us
  Sparse model completed in       15us
[ 2] Fraction of non-zero entries: 0.00010
  Dense  model completed in      608us
  Sparse model completed in       15us
[ 3] Fraction of non-zero entries: 0.00032
  Dense  model completed in      598us
  Sparse model completed in       16us
[ 4] Fraction of non-zero entries: 0.00100
  Dense  model completed in      597us
  Sparse model completed in       18us
[ 5] Fraction of non-zero entries: 0.00316
  Dense  model completed in      593us
  Sparse model completed in       21us
[ 6] Fraction of non-zero entries: 0.01000
  Dense  model completed in      612us
  Sparse model completed in       37us
[ 7] Fraction of non-zero entries: 0.03162
  Dense  model completed in      602us
  Sparse model completed in       89us
[ 8] Fraction of non-zero entries: 0.10000
  Dense  model completed in      613us
  Sparse model completed in      310us
[ 9] Fraction of non-zero entries: 0.31623
  Dense  model completed in      596us
  Sparse model completed in      796us
[10] Fraction of non-zero entries: 1.00000
  Dense  model completed in      597us
  Sparse model completed in      671us

[plot: model runtime vs. fraction of non-zero entries]
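The model comparison can be sketched without the Flax dependency by using a plain `y = x @ W + b` stand-in for what nn.Dense computes; the layer width, input shape, and sparsity level below are assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

n_in, n_out, frac = 1000, 1000, 0.001     # layer sizes / sparsity (assumptions)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Kernel with roughly the requested fraction of non-zero entries.
w = jnp.where(jax.random.uniform(k1, (n_in, n_out)) < frac,
              jax.random.normal(k2, (n_in, n_out)), 0.0)
b = jnp.zeros(n_out)
x = jax.random.normal(k3, (1, n_in))

def dense_layer(w, b, x):                 # what nn.Dense computes
    return x @ w + b

dense_apply = jax.jit(dense_layer)
sparse_apply = jax.jit(sparse.sparsify(dense_layer))

# Same function, but with the kernel stored as BCOO.
w_sp = sparse.BCOO.fromdense(w)
print(dense_apply(w, b, x).shape, sparse_apply(w_sp, b, x).shape)
```

Timing `dense_apply` against `sparse_apply` over a grid of sparsity fractions reproduces the trade-off shown above.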

Sparse JAX Specifics

Sparse JAX BCOO supports sparsification of:
dict_keys([abs, asin, asinh, atan, atanh, bessel_i1e, expm1, log1p, sign, sin, sinh, sqrt, tan, tanh, convert_element_type, copy, imag, neg, real, broadcast_in_dim, concatenate, conv_general_dilated, dot_general, dynamic_slice, reshape, rev, slice, squeeze, integer_pow, transpose, add, sub, mul, div, reduce_sum, gather, while, pjit, scan, cond, todense, custom_jvp_call])
Sparse JAX BCSR supports sparsification of:
dict_keys([abs, asin, asinh, atan, atanh, bessel_i1e, expm1, log1p, sign, sin, sinh, sqrt, tan, tanh, convert_element_type, copy, imag, neg, real, dot_general, broadcast_in_dim, concatenate, integer_pow, todense, custom_jvp_call])
{ lambda ; a:f32[6] b:f32[6]. let c:f32[6] = add a b in (c,) }
{ lambda ; a:f32[6] b:f32[6] c:i32[6,1]. let
    d:f32[6] = bcoo_todense[
      spinfo=SparseInfo(shape=(6,), indices_sorted=True, unique_indices=True)
    ] b c
    e:f32[6] = add d a
  in (e,) }
{ lambda ; a:f32[6] b:i32[6,1] c:f32[6] d:i32[6,1]. let
    e:i32[12,1] = concatenate[dimension=0] b d
    f:f32[12] = concatenate[dimension=0] a c
  in (f, e) }
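The jaxprs above can be regenerated with jax.make_jaxpr: sparsify rewrites a dense-plus-sparse add into bcoo_todense followed by a plain add, while a sparse-plus-sparse add concatenates the indices and data of the two operands. A minimal sketch (the input values are arbitrary non-zeros so the BCOO has nse=6 as above):

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

f = lambda a, b: a + b
a = jnp.arange(6.0)
# All entries non-zero, so fromdense keeps six stored elements.
b_sp = sparse.BCOO.fromdense(jnp.arange(1.0, 7.0))

print(jax.make_jaxpr(f)(a, a))                         # plain dense add
print(jax.make_jaxpr(sparse.sparsify(f))(a, b_sp))     # bcoo_todense + add
print(jax.make_jaxpr(sparse.sparsify(f))(b_sp, b_sp))  # concatenate indices/data
```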