Sandbox for playing with Sparse JAX and the sparse compiler on GPU.

See the colab; it is currently available only to Googlers.


This colab must be run on a special GPU runtime with XLA-NEXT:CPU+ enabled.

The XLA-NEXT:CPU+ compiler (the plus stands for GPU!) is needed to accelerate sparse operations on the GPU. Obtaining this compiler requires running from a custom-built colab binary, since the "GPU part" of XLA-NEXT is not in production.

Use the cell below to select between CUDA libgen and CUDA codegen for sparsified code (switching between them always requires a hard reset of the runtime):

  • xla_cpu_sparse_cuda_threads=1: CUDA LIBGEN, i.e. calls into a library such as cuSPARSE
  • xla_cpu_sparse_cuda_threads>1: CUDA CODEGEN, i.e. generated code that runs on CUDA threads
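The toggle can be sketched as follows. This assumes (the colab does not confirm it) that the flag is passed through the standard `XLA_FLAGS` environment variable before JAX initializes its backends; the helper names `sparse_cuda_mode` and `with_sparse_cuda_threads` are hypothetical. XLA parses `XLA_FLAGS` once per process, which is why a new value only takes effect after a hard reset of the runtime.

```python
import os

FLAG = "xla_cpu_sparse_cuda_threads"  # flag name taken from the list above


def sparse_cuda_mode(threads: int) -> str:
    """Map a flag value to the mode it selects (hypothetical helper)."""
    if threads < 1:
        raise ValueError(f"{FLAG} must be >= 1")
    return "libgen (cuSPARSE)" if threads == 1 else "codegen (CUDA threads)"


def with_sparse_cuda_threads(xla_flags: str, threads: int) -> str:
    """Return an XLA_FLAGS string with FLAG set to `threads`.

    Earlier settings of the same flag are dropped so the new value wins;
    all other flags keep their original order.
    """
    kept = [f for f in xla_flags.split() if not f.startswith(f"--{FLAG}=")]
    kept.append(f"--{FLAG}={threads}")
    return " ".join(kept)


# Assumption: this must run before `import jax` (i.e. right after a hard
# runtime reset), because XLA reads XLA_FLAGS only once per process.
os.environ["XLA_FLAGS"] = with_sparse_cuda_threads(
    os.environ.get("XLA_FLAGS", ""), threads=1)
print(sparse_cuda_mode(1), "->", os.environ["XLA_FLAGS"])
```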

Setup

Importing modules from P4 HEAD.

Profile SpMV on GPU

[ 0] Fraction of non-zero entries: 0.00000
  Dense ARRAY matvec completed in  304,892us
  Sparse BCOO SpMV   completed in       26us 0 nse
  Sparse BCSR SpMV   completed in      440us 0 nse
[ 1] Fraction of non-zero entries: 0.01000
  Dense ARRAY matvec completed in  300,976us
  Sparse BCOO SpMV   completed in    4,703us 2683428 nse
  Sparse BCSR SpMV   completed in    9,363us 2683428 nse
[ 2] Fraction of non-zero entries: 0.02000
  Dense ARRAY matvec completed in  301,270us
  Sparse BCOO SpMV   completed in   10,027us 5367424 nse
  Sparse BCSR SpMV   completed in   10,465us 5367424 nse
[ 3] Fraction of non-zero entries: 0.03000
  Dense ARRAY matvec completed in  300,594us
  Sparse BCOO SpMV   completed in   15,497us 8050085 nse
  Sparse BCSR SpMV   completed in   15,433us 8050085 nse
[ 4] Fraction of non-zero entries: 0.04000
  Dense ARRAY matvec completed in  300,976us
  Sparse BCOO SpMV   completed in   20,903us 10733923 nse
  Sparse BCSR SpMV   completed in   20,344us 10733923 nse
[ 5] Fraction of non-zero entries: 0.05000
  Dense ARRAY matvec completed in  300,844us
  Sparse BCOO SpMV   completed in   26,383us 13419841 nse
  Sparse BCSR SpMV   completed in   24,971us 13419841 nse
[ 6] Fraction of non-zero entries: 0.06000
  Dense ARRAY matvec completed in  302,289us
  Sparse BCOO SpMV   completed in   31,947us 16104215 nse
  Sparse BCSR SpMV   completed in   29,716us 16104215 nse
[ 7] Fraction of non-zero entries: 0.07000
  Dense ARRAY matvec completed in  300,723us
  Sparse BCOO SpMV   completed in   38,010us 18790488 nse
  Sparse BCSR SpMV   completed in   35,196us 18790488 nse
[ 8] Fraction of non-zero entries: 0.08000
  Dense ARRAY matvec completed in  301,040us
  Sparse BCOO SpMV   completed in   43,050us 21472986 nse
  Sparse BCSR SpMV   completed in   40,224us 21472986 nse
[ 9] Fraction of non-zero entries: 0.09000
  Dense ARRAY matvec completed in  300,776us
  Sparse BCOO SpMV   completed in   48,697us 24156594 nse
  Sparse BCSR SpMV   completed in   45,140us 24156594 nse
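In the rows above, BCOO and BCSR are the (batched) COO and CSR sparse formats, and `nse` is the number of stored elements. The plain-Python sketch below only illustrates what the two layouts hold and how a matrix-vector product walks each one; it is not the JAX, cuSPARSE, or generated kernel. `expected_nse` is a hypothetical helper that assumes each entry is non-zero independently with probability `density`. Under that assumption, and assuming the matrix has the 16384×16384 shape printed in the MV section, a density of 0.01 gives about 2,684,355 stored elements, close to the 2,683,428 logged.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Sparse = Tuple[List[float], List[int], List[int]]


def to_coo(a: Matrix) -> Sparse:
    """COO: (data, rows, cols), one triple per stored (non-zero) element."""
    data, rows, cols = [], [], []
    for i, row in enumerate(a):
        for j, x in enumerate(row):
            if x != 0:
                data.append(x)
                rows.append(i)
                cols.append(j)
    return data, rows, cols


def to_csr(a: Matrix) -> Sparse:
    """CSR: (data, indices, indptr); row i owns data[indptr[i]:indptr[i+1]]."""
    data, indices, indptr = [], [], [0]
    for row in a:
        for j, x in enumerate(row):
            if x != 0:
                data.append(x)
                indices.append(j)
        indptr.append(len(data))
    return data, indices, indptr


def coo_matvec(coo: Sparse, n_rows: int, v: Sequence[float]) -> List[float]:
    data, rows, cols = coo
    out = [0.0] * n_rows
    # Scatter-add, since COO triples need not be sorted by row.
    for x, i, j in zip(data, rows, cols):
        out[i] += x * v[j]
    return out


def csr_matvec(csr: Sparse, v: Sequence[float]) -> List[float]:
    data, indices, indptr = csr
    # Each output row is a dot product over one contiguous slice of data.
    return [sum((data[k] * v[indices[k]] for k in range(indptr[i], indptr[i + 1])), 0.0)
            for i in range(len(indptr) - 1)]


def expected_nse(n_rows: int, n_cols: int, density: float) -> float:
    """Expected stored elements if each entry is non-zero with prob. density."""
    return density * n_rows * n_cols


A = [[1.0, 0.0, 2.0],
     [0.0, 0.0, 0.0],
     [0.0, 3.0, 0.0]]
v = [1.0, 2.0, 3.0]
print(coo_matvec(to_coo(A), 3, v))              # [7.0, 0.0, 6.0]
print(csr_matvec(to_csr(A), v))                 # [7.0, 0.0, 6.0]
print(round(expected_nse(16384, 16384, 0.01)))  # 2684355
```

CSR keeps each row's entries contiguous, so a row's dot product needs no scatter; COO needs only the triples, at the cost of a scatter-add into the output.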


Profile SpMM on GPU

[ 0] Fraction of non-zero entries: 0.00000
  Dense ARRAY matmul completed in 3,090,630us
  Sparse BCSR SpMM   completed in    3,286us 0 nse
[ 1] Fraction of non-zero entries: 0.01000
  Dense ARRAY matmul completed in 2,911,031us
  Sparse BCSR SpMM   completed in    3,339us 10461 nse
[ 2] Fraction of non-zero entries: 0.02000
  Dense ARRAY matmul completed in 2,987,564us
  Sparse BCSR SpMM   completed in    3,507us 20727 nse
[ 3] Fraction of non-zero entries: 0.03000
  Dense ARRAY matmul completed in 2,959,827us
  Sparse BCSR SpMM   completed in    3,682us 31319 nse
[ 4] Fraction of non-zero entries: 0.04000
  Dense ARRAY matmul completed in 2,987,060us
  Sparse BCSR SpMM   completed in    4,288us 41781 nse
[ 5] Fraction of non-zero entries: 0.05000
  Dense ARRAY matmul completed in 3,000,965us
  Sparse BCSR SpMM   completed in    3,902us 52246 nse
[ 6] Fraction of non-zero entries: 0.06000
  Dense ARRAY matmul completed in 3,001,546us
  Sparse BCSR SpMM   completed in    4,162us 62740 nse
[ 7] Fraction of non-zero entries: 0.07000
  Dense ARRAY matmul completed in 2,950,979us
  Sparse BCSR SpMM   completed in    4,341us 73046 nse
[ 8] Fraction of non-zero entries: 0.08000
  Dense ARRAY matmul completed in 2,883,313us
  Sparse BCSR SpMM   completed in    4,383us 83590 nse
[ 9] Fraction of non-zero entries: 0.09000
  Dense ARRAY matmul completed in 2,876,997us
  Sparse BCSR SpMM   completed in    4,621us 94122 nse
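The SpMM `nse` counts are consistent with a 1024×1024 sparse operand, the shape printed in the GEMM section: 0.01 × 1024² ≈ 10,486, against 10,461 logged. A plain-Python sketch of the operation being timed, a CSR matrix times a dense matrix, is below. It illustrates the algorithm only and is not the generated or cuSPARSE kernel; the `csr_matmul` name is hypothetical.

```python
from typing import List, Sequence, Tuple

CSR = Tuple[List[float], List[int], List[int]]  # (data, indices, indptr)


def csr_matmul(csr: CSR, b: Sequence[Sequence[float]]) -> List[List[float]]:
    """Compute A @ B for A in CSR form and a dense B with equal-length rows."""
    data, indices, indptr = csr
    n_cols = len(b[0])
    out = []
    for i in range(len(indptr) - 1):
        acc = [0.0] * n_cols
        # Each stored A[i, j] scales row j of B into row i of the result,
        # so the work grows with nse * n_cols rather than rows * cols * n_cols.
        for k in range(indptr[i], indptr[i + 1]):
            a_ij, row_b = data[k], b[indices[k]]
            for c in range(n_cols):
                acc[c] += a_ij * row_b[c]
        out.append(acc)
    return out


# A = [[1, 0, 2], [0, 0, 0], [0, 3, 0]] in CSR form.
A_csr: CSR = ([1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 2, 3])
B = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
print(csr_matmul(A_csr, B))  # [[3.0, 2.0], [0.0, 0.0], [0.0, 3.0]]
```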


Profile MV on CPU/GPU

[ 0] Fraction of non-zero entries: 0.00000
  Dense ARRAY matvec completed in    1,311us
[ 1] Fraction of non-zero entries: 0.01000
  Dense ARRAY matvec completed in    1,298us
[ 2] Fraction of non-zero entries: 0.02000
  Dense ARRAY matvec completed in    1,330us
[ 3] Fraction of non-zero entries: 0.03000
  Dense ARRAY matvec completed in    1,321us
[ 4] Fraction of non-zero entries: 0.04000
  Dense ARRAY matvec completed in    1,317us
[ 5] Fraction of non-zero entries: 0.05000
  Dense ARRAY matvec completed in    1,331us
[ 6] Fraction of non-zero entries: 0.06000
  Dense ARRAY matvec completed in    1,354us
[ 7] Fraction of non-zero entries: 0.07000
  Dense ARRAY matvec completed in    1,303us
[ 8] Fraction of non-zero entries: 0.08000
  Dense ARRAY matvec completed in    1,308us
[ 9] Fraction of non-zero entries: 0.09000
  Dense ARRAY matvec completed in    1,292us
(16384, 16384)
[1311.2015090882778, 1298.0406172573566, 1330.2603038027883, 1321.4902952313423, 1317.8099179640412, 1331.6298834979534, 1354.8784190788865, 1303.788903169334, 1308.6162973195314, 1292.1323999762535]
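The list above holds the full-precision timings, in microseconds, that the log lines truncate to whole microseconds. The colab's timing code is not shown, so the following is a minimal sketch of one way to measure such per-call averages; the helper `time_us` and the log format string are assumptions. JAX dispatches work asynchronously, so the clock should only stop after `block_until_ready()` returns; the helper calls it only when the result provides it, so it also runs on plain Python values.

```python
import time
from typing import Any, Callable


def time_us(fn: Callable[..., Any], *args: Any, repeats: int = 10) -> float:
    """Average wall-clock time of fn(*args) in microseconds (hypothetical helper).

    One untimed warm-up call absorbs one-time costs such as jax.jit compilation.
    If the result has block_until_ready (JAX arrays do), it is called so the
    measurement covers the device computation, not just the async dispatch.
    """
    def run() -> Any:
        out = fn(*args)
        sync = getattr(out, "block_until_ready", None)
        if sync is not None:
            sync()
        return out

    run()  # warm-up, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) * 1e6 / repeats


# Plain-Python stand-in workload; with JAX one would pass e.g. a jitted matvec.
t = time_us(sum, range(100_000), repeats=5)
# Assumed format, chosen to resemble the log lines above (truncated, width 8).
print(f"  Dense ARRAY matvec completed in {int(t):8,}us")
```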

Profile GEMM on CPU/GPU

[ 0] Fraction of non-zero entries: 0.00000
  Dense ARRAY matmul completed in      248us
[ 1] Fraction of non-zero entries: 0.01000
  Dense ARRAY matmul completed in      262us
[ 2] Fraction of non-zero entries: 0.02000
  Dense ARRAY matmul completed in      240us
[ 3] Fraction of non-zero entries: 0.03000
  Dense ARRAY matmul completed in      291us
[ 4] Fraction of non-zero entries: 0.04000
  Dense ARRAY matmul completed in      277us
[ 5] Fraction of non-zero entries: 0.05000
  Dense ARRAY matmul completed in      264us
[ 6] Fraction of non-zero entries: 0.06000
  Dense ARRAY matmul completed in      245us
[ 7] Fraction of non-zero entries: 0.07000
  Dense ARRAY matmul completed in      243us
[ 8] Fraction of non-zero entries: 0.08000
  Dense ARRAY matmul completed in      234us
[ 9] Fraction of non-zero entries: 0.09000
  Dense ARRAY matmul completed in      240us
(1024, 1024)
[248.88459593057632, 262.53070682287216, 240.3387101367116, 291.3612173870206, 277.50979643315077, 264.9290021508932, 245.88459637016058, 243.75307839363813, 234.11749862134457, 240.35531096160412]
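A dense kernel performs the same arithmetic no matter what fraction of its entries is zero, so variation across the density rows above is expected to come from measurement noise rather than from the density itself. A small sketch that summarizes the full-precision GEMM timings printed above (values copied verbatim) shows how large that variation is:

```python
import statistics

# Dense 1024x1024 GEMM timings (us) at densities 0.00..0.09, copied from above.
gemm_us = [248.88459593057632, 262.53070682287216, 240.3387101367116,
           291.3612173870206, 277.50979643315077, 264.9290021508932,
           245.88459637016058, 243.75307839363813, 234.11749862134457,
           240.35531096160412]

mean = statistics.mean(gemm_us)
stdev = statistics.stdev(gemm_us)
spread = (max(gemm_us) - min(gemm_us)) / mean  # max-min range relative to mean
print(f"mean {mean:.1f}us, stdev {stdev:.1f}us, range {spread:.0%} of mean")
```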