See the colab (https://colab.corp.google.com/drive/1SlSXVchyYo4qd2AAtbEV3HZvsX65_yrn). Currently, it is only available to Googlers.
This Colab can be run under the DeepMind CPU runtime.
Setup
import functools
import jax
# This is required to get Sparse JAX (experimental for now).
from jax.experimental import sparse
import jax.numpy as jnp
import os
from colabtools import adhoc_import
import timeit
from typing import Callable
CITC_CLIENT = ""
CITC_USER = ""
CL_NUMBER = ""
# Use XLA-NEXT for fast sparse ops.
# --xla_cpu_use_xla_runtime enables XLA-NEXT, which has fast sparse ops.
# --xla_cpu_enable_mlir_tiling_and_fusion=false is needed to work around a few bugs;
# eventually it should be removed.
os.environ['XLA_FLAGS'] = '--xla_cpu_use_xla_runtime --xla_cpu_enable_mlir_tiling_and_fusion=false'
Importing modules from P4 HEAD.
Sparse convolution
import PIL.Image
import math
from google3.pyglib import gfile
import numpy as np
import matplotlib.pyplot as plt
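# NOTE: `sparse_jit` is not defined in this snippet. A minimal sketch, assuming it
# simply composes jax.jit with sparse.sparsify so the traced convolution accepts
# sparse (BCSR) operands (assumed helper, not part of the original Colab):
def sparse_jit(fun):
  return jax.jit(sparse.sparsify(fun))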
@sparse_jit
def conv(x, y):
  # One (0, 0) padding pair and a unit stride per spatial dimension of the 2-D kernel.
  padding = [(0, 0) for _ in range(y.ndim)]
  strides = tuple(1 for _ in range(y.ndim))
  # y[None, None] gives the kernel singleton output/input feature dims: (1, 1, kh, kw).
  return jax.lax.conv_general_dilated(x, y[None, None], window_strides=strides, padding=padding)
dn_image = np.array(
    PIL.Image.open(gfile.Open('/x20/users/pe/peiming/image/hello_sparsity.jpg', 'rb')).convert('L'),
    dtype='int8').reshape(1, 1, 960, 960)
sp_image = sparse.BCSR.fromdense(dn_image, n_batch=2)
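# With n_batch=2, the leading (1, 1) dims of sp_image are dense batch dims and only the
# trailing 960x960 plane is stored in CSR form; sp_image.nse is the number of stored nonzeros.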
# A common diagonal edge-detection kernel.
kernel = jnp.array([[1, 1, 0], [1, 0,-1], [0,-1,-1]], dtype='int8')
#kernel = jnp.array([1,1,1,1,1,0,-1,-1,-1,-1,-1], dtype='int8')
#kernel = jax.lax.broadcast(kernel, (11,))
print("Sparsity =", sp_image.nse / float(math.prod(sp_image.shape)))
sparse_conv = lambda: conv(sp_image, kernel).block_until_ready()
dense_conv = lambda: conv(dn_image, kernel).block_until_ready()

# Warm up: run each convolution once so compilation time is excluded from the timings below.
sparse_ret = sparse_conv()
dense_ret = dense_conv()
plt.imshow(dn_image[0,0,:,:].astype('uint8'))
Sparsity = 0.13420355902777778
<matplotlib.image.AxesImage at 0x7f51b80d51e0>
[Image output: the original grayscale input image]
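Only about 13% of the pixels are nonzero, so the BCSR representation stores far fewer elements than the dense image.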
REPEATS = 1
LOOPS = 10
SECONDS_TO_MICROS = 1_000 * 1_000
t = timeit.repeat(
sparse_conv,
repeat=REPEATS,
number=LOOPS,
)
print("Sparse Convolution Time(ms):", SECONDS_TO_MICROS * max(t) / LOOPS)
t = timeit.repeat(
dense_conv,
repeat=REPEATS,
number=LOOPS,
)
print("Dense Convolution Time(ms):", SECONDS_TO_MICROS * max(t) / LOOPS)
Sparse Convolution Time(us): 5355.268900166266
Dense Convolution Time(us): 12721.699800749775
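On this run, that is about 5.4 ms per sparse convolution versus 12.7 ms per dense one, roughly a 2.4x speedup from exploiting the ~13% density.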
print("sparse convolution result.")
plt.imshow(sparse_ret[0,0,:,:].astype('uint8'))
sparse convolution result.
<matplotlib.image.AxesImage at 0x7efdada839a0>
[Image output: the sparse convolution result]
print("PEIMING: dense convolution result.")
plt.imshow(dense_ret[0,0,:,:].astype('uint8'))
dense convolution result.
<matplotlib.image.AxesImage at 0x7fe603184b80>
[Image output: the dense convolution result]
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-09-03 UTC."],[[["\u003cp\u003eThis Colab demonstrates the usage of sparse convolution in JAX with a focus on performance.\u003c/p\u003e\n"],["\u003cp\u003eIt compares the execution time of sparse and dense convolution on an image, showcasing significant speed improvements with sparse representation (over 2x faster).\u003c/p\u003e\n"],["\u003cp\u003eThe Colab utilizes DeepMind CPU Runtime and leverages XLA-NEXT for optimized sparse operations.\u003c/p\u003e\n"],["\u003cp\u003eIt provides a practical example of edge detection using a common kernel applied to a sparse image representation.\u003c/p\u003e\n"]]],[],null,["See [colab](https://colab.corp.google.com/drive/1SlSXVchyYo4qd2AAtbEV3HZvsX65_yrn). Currently, it is only available to googlers.\n\nThis Colab can run under DeepMind CPU Runtime.\n\n### Setup\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n import functools\n import jax\n # This is required to get Sparse JAX (experimental for now).\n from jax.experimental import sparse\n import jax.numpy as jnp\n import os\n\n from colabtools import adhoc_import\n import timeit\n from typing import Callable\n\n\n\n\n CITC_CLIENT = \"\" \n\n CITC_USER = \"\" \n\n CL_NUMBER = \"\"\n\n\n # Use XLA-NEXT for fast sparse ops\n # --xla_cpu_use_xla_runtime enables XLA-NEXT which has fast sparse ops\n # --xla_cpu_enable_mlir_tiling_and_fusion=false needed to workaround a few bugs.\n # Eventually, it should be removed\n os.environ['XLA_FLAGS'] = '--xla_cpu_use_xla_runtime --xla_cpu_enable_mlir_tiling_and_fusion=false'\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nImporting modules from P4 HEAD.\n```\n\n### sparse convolution\n\n\u003cbr /\u003e\n\nToggle code\n\n\u003cbr /\u003e\n\n import PIL\n import math\n from google3.pyglib import gfile\n import numpy as np\n import matplotlib.pyplot as plt\n\n @sparse_jit\n def conv(x, y):\n padding = [(0, 0) for s in range(0, y.ndim)]\n strides = tuple(1 for s in range(0, y.ndim))\n return jax.lax.conv_general_dilated(x, y[None, None], window_strides=strides, padding=padding)\n\n\n dn_image = np.array(PIL.Image.open(gfile.Open('/x20/users/pe/peiming/image/hello_sparsity.jpg', 'rb')).convert('L'), dtype='int8').reshape(1,1,960,960)\n sp_image = sparse.BCSR.fromdense(dn_image, n_batch=2)\n # A common diagnol edge detection kernel\n kernel = jnp.array([[1, 1, 0], [1, 0,-1], [0,-1,-1]], dtype='int8')\n #kernel = jnp.array([1,1,1,1,1,0,-1,-1,-1,-1,-1], dtype='int8')\n #kernel = jax.lax.broadcast(kernel, (11,))\n\n print(\"Sparsity =\", sp_image.nse / float(math.prod(sp_image.shape)))\n\n sparse_conv = lambda : conv(sp_image, kernel).block_until_ready()\n dense_conv = lambda : conv(dn_image, kernel).block_until_ready()\n\n #warmup\n sparse_ret = sparse_conv()\n dense_ret = dense_conv()\n plt.imshow(dn_image[0,0,:,:].astype('uint8'))\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n```\nSparsity = 0.13420355902777778\n\u003cmatplotlib.image.AxesImage at 0x7f51b80d51e0\u003e\n```\n\n REPEATS = 1\n LOOPS = 10\n SECONDS_TO_MICROS = 1_000 * 1_000\n\n t = timeit.repeat(\n sparse_conv,\n repeat=REPEATS,\n number=LOOPS,\n )\n 
print(\"Sparse Convolution Time(ms):\", SECONDS_TO_MICROS * max(t) / LOOPS)\n\n t = timeit.repeat(\n dense_conv,\n repeat=REPEATS,\n number=LOOPS,\n )\n print(\"Dense Convolution Time(ms):\", SECONDS_TO_MICROS * max(t) / LOOPS)\n\n```\nSparse Convolution Time(ms): 5355.268900166266\nDense Convolution Time(ms): 12721.699800749775\n``` \n\n print(\"sparse convolution result.\")\n plt.imshow(sparse_ret[0,0,:,:].astype('uint8'))\n\n```\nsparse convolution result.\n\u003cmatplotlib.image.AxesImage at 0x7efdada839a0\u003e\n```\n\n print(\"PEIMING: dense convolution result.\")\n plt.imshow(dense_ret[0,0,:,:].astype('uint8'))\n\n```\nPEIMING: dense convolution result.\n\u003cmatplotlib.image.AxesImage at 0x7fe603184b80\u003e\n```"]]