Sandbox for playing with Sparse JAX and the sparse compiler on GPU.

See the colab. It is currently available only to Googlers.


This colab is meant to be run on a special GPU runtime with XLA-NEXT:CPU+ enabled.

The XLA-NEXT:CPU+ compiler (the plus stands for GPU!) is needed to accelerate sparse operations on the GPU. Obtaining this compiler requires running a custom-built colab binary, since the "GPU part" of XLA-NEXT is not in production.

Use the cell below to select between CUDA libgen and CUDA codegen for sparsified code (toggling this always requires a hard reset of the runtime):

  • xla_cpu_sparse_cuda_threads=1: CUDA libgen, i.e., library calls (e.g., to cuSPARSE)
  • xla_cpu_sparse_cuda_threads>1: CUDA codegen, i.e., generated kernels that run on CUDA threads
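The selection cell itself is not shown. A minimal sketch of how the flag might be set from Python, assuming it is passed to XLA through the standard XLA_FLAGS environment variable before JAX initializes its backend. The helper name set_sparse_cuda_threads is hypothetical. Changing the value afterwards still needs the hard runtime reset mentioned above.

```python
import os

# Flag name taken from the text above; passing it via XLA_FLAGS is an assumption.
FLAG = "--xla_cpu_sparse_cuda_threads"

def set_sparse_cuda_threads(threads: int) -> str:
    """Hypothetical helper: 1 selects CUDA libgen, >1 selects CUDA codegen."""
    if threads < 1:
        raise ValueError("xla_cpu_sparse_cuda_threads must be >= 1")
    # Keep any other XLA flags, but drop an earlier setting of this one.
    kept = [f for f in os.environ.get("XLA_FLAGS", "").split()
            if not f.startswith(FLAG + "=")]
    kept.append(f"{FLAG}={threads}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]

# Must run before `import jax` triggers backend initialization.
set_sparse_cuda_threads(1)  # CUDA libgen (library calls such as cuSPARSE)
```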

Setup

Importing modules from P4 HEAD.

Profile SDDMM on GPU
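SDDMM (sampled dense-dense matrix multiplication) computes S ⊙ (A·B). The product of the dense matrices A and B is needed only at the positions where the sparse matrix S is non-zero. The sketch below contrasts a dense reference with a BCOO version that evaluates one dot product per stored element of S. It uses the public jax.experimental.sparse BCOO API. The function names sddmm_dense and sddmm_bcoo are hypothetical, and this is not necessarily how the profiled kernels are implemented.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def sddmm_dense(S, A, B):
    # Dense reference: computes the full m x n product, then masks it with S.
    return S * (A @ B)

def sddmm_bcoo(S_sp, A, B):
    # Sparse version: one length-k dot product per stored (row, col) of S.
    rows, cols = S_sp.indices[:, 0], S_sp.indices[:, 1]
    dots = jnp.einsum("ik,ki->i", A[rows], B[:, cols])
    return sparse.BCOO((S_sp.data * dots, S_sp.indices), shape=S_sp.shape)

key_s, key_a, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
m, k, n = 64, 32, 48
mask = jax.random.uniform(key_s, (m, n)) < 0.1  # ~10% non-zero entries
S = jnp.where(mask, 1.0, 0.0)
A = jax.random.normal(key_a, (m, k))
B = jax.random.normal(key_b, (k, n))

out = sddmm_bcoo(sparse.BCOO.fromdense(S), A, B)
```

The dense version does m·n·k multiply-adds regardless of sparsity, while the sparse one does roughly nse·k. This is why the sparse timings in the log scale with the non-zero fraction and the dense ones stay flat.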

[ 0] Fraction of non-zero entries: 0.00000
  Dense  ARRAY SDDMM completed in 2,910,600us
  Sparse BCOO  SDDMM completed in      274us 0 nse
  Sparse BCSR  SDDMM completed in    1,035us 0 nse
  Sparse GPU   SDDMM completed in    2,684us 0 nse
[ 1] Fraction of non-zero entries: 0.10000
  Dense  ARRAY SDDMM completed in 2,871,439us
  Sparse BCOO  SDDMM completed in  159,393us 104638 nse
  Sparse BCSR  SDDMM completed in  177,083us 104638 nse
  Sparse GPU   SDDMM completed in   17,192us 104638 nse
[ 2] Fraction of non-zero entries: 0.20000
  Dense  ARRAY SDDMM completed in 2,928,734us
  Sparse BCOO  SDDMM completed in  219,174us 209835 nse
  Sparse BCSR  SDDMM completed in  283,097us 209835 nse
  Sparse GPU   SDDMM completed in    7,729us 209835 nse
[ 3] Fraction of non-zero entries: 0.30000
  Dense  ARRAY SDDMM completed in 3,022,233us
  Sparse BCOO  SDDMM completed in  328,827us 314145 nse
  Sparse BCSR  SDDMM completed in  427,342us 314145 nse
  Sparse GPU   SDDMM completed in    8,328us 314145 nse
[ 4] Fraction of non-zero entries: 0.40000
  Dense  ARRAY SDDMM completed in 2,898,277us
  Sparse BCOO  SDDMM completed in  422,196us 419046 nse
  Sparse BCSR  SDDMM completed in  520,667us 419046 nse
  Sparse GPU   SDDMM completed in    9,146us 419046 nse
[ 5] Fraction of non-zero entries: 0.50000
  Dense  ARRAY SDDMM completed in 2,941,023us
  Sparse BCOO  SDDMM completed in  519,081us 523768 nse
  Sparse BCSR  SDDMM completed in  635,784us 523768 nse
  Sparse GPU   SDDMM completed in    9,952us 523768 nse
[ 6] Fraction of non-zero entries: 0.60000
  Dense  ARRAY SDDMM completed in 2,872,818us
  Sparse BCOO  SDDMM completed in  623,089us 628597 nse
  Sparse BCSR  SDDMM completed in  745,292us 628597 nse
  Sparse GPU   SDDMM completed in   11,055us 628597 nse
[ 7] Fraction of non-zero entries: 0.70000
  Dense  ARRAY SDDMM completed in 2,911,507us
  Sparse BCOO  SDDMM completed in  726,227us 734097 nse
  Sparse BCSR  SDDMM completed in  865,237us 734097 nse
  Sparse GPU   SDDMM completed in   11,420us 734097 nse
[ 8] Fraction of non-zero entries: 0.80000
  Dense  ARRAY SDDMM completed in 2,821,479us
  Sparse BCOO  SDDMM completed in  846,639us 838722 nse
  Sparse BCSR  SDDMM completed in 1,002,814us 838722 nse
  Sparse GPU   SDDMM completed in   11,851us 838722 nse
[ 9] Fraction of non-zero entries: 0.90000
  Dense  ARRAY SDDMM completed in 2,861,378us
  Sparse BCOO  SDDMM completed in  948,079us 943895 nse
  Sparse BCSR  SDDMM completed in 1,126,073us 943895 nse
  Sparse GPU   SDDMM completed in   13,150us 943895 nse
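In the log above, nse is the number of specified elements, i.e., the entries that a BCOO or BCSR matrix actually stores. For a random mask it tracks the non-zero fraction times the matrix size. The following sketch builds such a test matrix and reads nse from both formats. It assumes a hypothetical helper random_sparse and an arbitrary 256×256 shape, since the notebook's own matrix size and generator are not shown.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def random_sparse(key, shape, density):
    """Hypothetical helper: dense matrix with entries non-zero w.p. `density`."""
    key_mask, key_vals = jax.random.split(key)
    mask = jax.random.uniform(key_mask, shape) < density
    # Values drawn from [0.5, 1.5) so that every masked-in entry is non-zero.
    vals = jax.random.uniform(key_vals, shape, minval=0.5, maxval=1.5)
    return jnp.where(mask, vals, 0.0)

M = random_sparse(jax.random.PRNGKey(0), (256, 256), 0.3)
coo = sparse.BCOO.fromdense(M)  # coordinate format: data + (row, col) indices
csr = sparse.BCSR.fromdense(M)  # compressed rows: data + column indices + indptr
print(coo.nse, csr.nse, 0.3 * M.size)  # both nse values are close to 0.3 * M.size
```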


Profile 2:4 on GPU
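2:4 structured sparsity requires each group of four consecutive values to contain at most two non-zeros. This is the pattern accelerated by the sparse tensor cores on recent NVIDIA GPUs. The sketch below prunes a dense matrix to this pattern by keeping the two largest-magnitude entries per group along the last axis, which is the reduction dimension when the matrix is the left operand of a matmul. It is written in plain jax.numpy with hypothetical helpers prune_2_4 and is_2_4. Magnitude pruning is a common heuristic and not necessarily how the matrices profiled here were produced.

```python
import jax.numpy as jnp

def prune_2_4(w):
    """Keep the 2 largest-magnitude entries in each group of 4 along the last axis."""
    rows, cols = w.shape
    if cols % 4 != 0:
        raise ValueError("last dimension must be a multiple of 4")
    groups = w.reshape(rows, cols // 4, 4)
    order = jnp.argsort(jnp.abs(groups), axis=-1)  # indices, smallest magnitude first
    ranks = jnp.argsort(order, axis=-1)            # rank 0..3 of each entry in its group
    return jnp.where(ranks >= 2, groups, 0).reshape(rows, cols)

def is_2_4(w):
    """True if every group of 4 along the last axis has at most 2 non-zeros."""
    groups = w.reshape(w.shape[0], -1, 4)
    return bool(jnp.all(jnp.count_nonzero(groups, axis=-1) <= 2))

w = jnp.array([[0.1, -0.9, 0.4, 0.3, 2.0, -0.2, 0.0, 1.5]])
wp = prune_2_4(w)  # first group keeps -0.9 and 0.4; second keeps 2.0 and 1.5
```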

  Dense ARRAY MatMul24 completed in       12us (16, 16)
  Sparse GPU NV2:4 completed in    1,555us (16, 16)
  Dense ARRAY MatMul24 completed in       41us (32, 32)
  Sparse GPU NV2:4 completed in    4,083us (32, 32)
  Dense ARRAY MatMul24 completed in      357us (64, 64)
  Sparse GPU NV2:4 completed in    3,351us (64, 64)
  Dense ARRAY MatMul24 completed in    3,003us (128, 128)
  Sparse GPU NV2:4 completed in      836us (128, 128)
  Dense ARRAY MatMul24 completed in   25,408us (256, 256)
  Sparse GPU NV2:4 completed in    1,465us (256, 256)
  Dense ARRAY MatMul24 completed in  254,398us (512, 512)
  Sparse GPU NV2:4 completed in    1,235us (512, 512)
  Dense ARRAY MatMul24 completed in 3,703,789us (1024, 1024)
  Sparse GPU NV2:4 completed in    4,868us (1024, 1024)
  Dense ARRAY MatMul24 completed in 33,101,322us (2048, 2048)
  Sparse GPU NV2:4 completed in   15,398us (2048, 2048)
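The timing code behind the "completed in Xus" lines is not shown. Below is a minimal sketch of a JAX microbenchmark, assuming compilation should be excluded from the measurement. It runs warm-up calls first, then waits on jax.block_until_ready for each timed call, because JAX dispatches device work asynchronously and would otherwise return before the computation finishes. The helper time_us and its defaults are hypothetical.

```python
import time
import jax
import jax.numpy as jnp

def time_us(fn, *args, warmup=1, repeats=5):
    """Hypothetical helper: median wall-clock time of fn(*args) in microseconds."""
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))  # absorbs JIT compilation
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for the device to finish
        samples.append((time.perf_counter() - start) * 1e6)
    samples.sort()
    return samples[len(samples) // 2]

matmul = jax.jit(jnp.matmul)
x = jnp.ones((256, 256))
print(f"  Dense ARRAY MatMul completed in {time_us(matmul, x, x):>10,.0f}us {x.shape}")
```

Reporting the median of several runs rather than a single call reduces noise from dispatch overhead, which matters most for the small (16, 16) to (64, 64) cases.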
