See the Colab notebook; it is currently available only to Googlers.
This Colab runs on the DeepMind CPU runtime.
Setup
import functools
import jax
# This is required to get Sparse JAX (experimental for now).
from jax.experimental import sparse
import jax.numpy as jnp
import os
from colabtools import adhoc_import
import timeit
from typing import Callable
# Left empty on purpose. Presumably these are read by an `adhoc_import` cell
# that is not shown here (empty values appear to mean "import from P4 HEAD") —
# TODO confirm.
CITC_CLIENT = ""
CITC_USER = ""
CL_NUMBER = ""


def append_xla_flags(*flags: str) -> None:
    """Add `flags` to the XLA_FLAGS environment variable without clobbering it.

    Entries already present in XLA_FLAGS (set by the runtime or by an earlier
    run of this cell) are kept. A flag is appended only if that exact string is
    not there yet, so re-running the cell does not duplicate entries. When
    XLA_FLAGS is unset, the result is simply `flags` joined by single spaces.

    Note: a pre-existing entry that sets the same flag to a *different* value
    (e.g. `...=true`) is not removed.
    """
    current = os.environ.get('XLA_FLAGS', '').split()
    for flag in flags:
        if flag not in current:
            current.append(flag)
    os.environ['XLA_FLAGS'] = ' '.join(current)


# Use XLA-NEXT for fast sparse ops.
# --xla_cpu_use_xla_runtime enables XLA-NEXT, which has fast sparse ops.
# --xla_cpu_enable_mlir_tiling_and_fusion=false is needed to work around a few
# bugs; it should eventually be removed.
# NOTE(review): XLA reads XLA_FLAGS when its backend is initialized, so this
# must run before the first JAX computation of the session. jax is already
# imported above; presumably that is fine because backend initialization is
# lazy — confirm.
append_xla_flags(
    '--xla_cpu_use_xla_runtime',
    '--xla_cpu_enable_mlir_tiling_and_fusion=false',
)
Importing modules from P4 HEAD.
sparse convolution
import PIL
import math
from google3.pyglib import gfile
import numpy as np
import matplotlib.pyplot as plt
@sparse_jit
def conv(x, y):
    """Unpadded ("valid"), stride-1 convolution of `x` with the single kernel `y`.

    `y` gets two leading size-1 axes, so it acts as a kernel with one output
    and one input feature. With conv_general_dilated's default dimension
    numbers, `x` therefore needs a single input channel. Strides and (zero)
    padding are given once per spatial dimension of `y`.

    `sparse_jit` is defined outside this view; presumably it jit-compiles the
    function and lets `x` be a sparse array — TODO confirm.
    """
    spatial_rank = y.ndim
    one_by_one_kernel = y[None, None]
    return jax.lax.conv_general_dilated(
        x,
        one_by_one_kernel,
        window_strides=(1,) * spatial_rank,
        padding=[(0, 0)] * spatial_rank,
    )
# Load the demo image as grayscale and add leading size-1 batch and channel
# dimensions: (N, C, H, W) = (1, 1, 960, 960), the layout that
# jax.lax.conv_general_dilated uses by default. The `with` block closes the
# gfile handle once the pixels have been read.
# NOTE(review): grayscale pixels are 0..255, and casting them to int8 wraps
# values above 127 to negatives (e.g. 255 -> -1). The plots cast back to uint8,
# so the image displays correctly, but the convolution runs on the wrapped
# values and its int8 output can wrap too — confirm this is intended.
with gfile.Open('/x20/users/pe/peiming/image/hello_sparsity.jpg', 'rb') as image_file:
    dn_image = np.array(
        PIL.Image.open(image_file).convert('L'), dtype='int8'
    ).reshape(1, 1, 960, 960)

# Sparse (BCSR) copy of the same image. n_batch=2 makes the two leading
# size-1 dimensions batch dimensions; the 960x960 plane is stored sparsely.
sp_image = sparse.BCSR.fromdense(dn_image, n_batch=2)

# A common diagonal edge-detection kernel (3x3).
kernel = jnp.array([[1, 1, 0], [1, 0, -1], [0, -1, -1]], dtype='int8')
# An alternative, larger kernel kept for experimentation:
# kernel = jnp.array([1, 1, 1, 1, 1, 0, -1, -1, -1, -1, -1], dtype='int8')
# kernel = jax.lax.broadcast(kernel, (11,))

# Fraction of entries that BCSR stores explicitly (nse = number of stored
# elements). Despite the label, this is the density; the sparsity is 1 minus it.
print("Sparsity =", sp_image.nse / math.prod(sp_image.shape))


def sparse_conv():
    """Convolve the BCSR image with `kernel` and wait for the result.

    JAX dispatches work asynchronously. block_until_ready() makes the call
    return only once the computation has finished, so the timing measures the
    convolution rather than just its dispatch.
    """
    return conv(sp_image, kernel).block_until_ready()


def dense_conv():
    """Convolve the dense image with `kernel` and wait for the result."""
    return conv(dn_image, kernel).block_until_ready()


# Warm-up: run each path once before timing so one-time costs are excluded.
# Those costs are presumably jit compilation via `sparse_jit`, defined
# elsewhere. The results are kept for the plots below.
sparse_ret = sparse_conv()
dense_ret = dense_conv()

# Show the input image, cast back to uint8 for display.
plt.imshow(dn_image[0, 0, :, :].astype('uint8'))
Sparsity = 0.13420355902777778 <matplotlib.image.AxesImage at 0x7f51b80d51e0>
# Benchmark configuration. Each measurement calls the function LOOPS times in a
# row, and the measurement is repeated REPEATS times.
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000 * 1_000


def _per_call_micros(fn: Callable[[], object]) -> float:
    """Time `fn` with timeit and return its per-call cost in microseconds.

    timeit.repeat returns one total (in seconds) per repetition, each covering
    LOOPS calls. The fastest repetition is used, as the timeit documentation
    recommends; slower repetitions mostly reflect interference from other load.
    This gives the same result as any other choice while REPEATS == 1.
    """
    totals = timeit.repeat(fn, repeat=REPEATS, number=LOOPS)
    return SECONDS_TO_MICROS * min(totals) / LOOPS


# The values are scaled by SECONDS_TO_MICROS, so the unit is microseconds.
print("Sparse Convolution Time(us):", _per_call_micros(sparse_conv))
print("Dense Convolution Time(us):", _per_call_micros(dense_conv))
Sparse Convolution Time(ms): 5355.268900166266 Dense Convolution Time(ms): 12721.699800749775
# Display the sparse-path output: only the [0, 0] (batch, channel) plane is
# shown, cast to uint8 for plotting.
print("sparse convolution result.")
plt.imshow(
    sparse_ret[0, 0, :, :].astype('uint8')
)
sparse convolution result. <matplotlib.image.AxesImage at 0x7efdada839a0>
# Display the dense-path output the same way as the sparse one, so the two
# images can be compared side by side. Only the [0, 0] (batch, channel) plane
# is shown, cast to uint8 for plotting. The message drops the stray author
# prefix so it matches the sparse message.
print("dense convolution result.")
plt.imshow(dense_ret[0, 0, :, :].astype('uint8'))
PEIMING: dense convolution result. <matplotlib.image.AxesImage at 0x7fe603184b80>