Thanks! I was indeed misunderstanding the situation with the JAX calling convention. I've made the suggested changes to the MLIR/calling code by removing the format definition entirely for dense tensors. If I change the shapes to dynamic on all tensors and invoke mlir-opt with --sparse-assembler="direct-out=true" (which was hinted at as working with dynamic shapes) or plain --sparse-assembler, I get the following error:
./sddmm.mlir:25:1: error: the sparse-tensor must have static shape
func.func @sddmm_kernel(
^
./sddmm.mlir:25:1: note: see current operation: %0 = "sparse_tensor.assemble"(%arg2, %arg3, %arg4) : (tensor<?xindex>, tensor<?xindex>, tensor<?xf64>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>
However, everything compiles again if I switch to static shapes. Performance measurements show the kernel is fusing correctly, and it now also returns correct results; the earlier wrong results were probably caused by an issue in my memory management.
For anyone interested, here is my code:
sddmm.mlir
// Define the CSR format
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>
// Define the iterator trait for SDDMM
#trait_sddmm = {
  // Defines the equivalent of "index notation" for a GUFunc.
  // For each input/output, define the input and output indices.
  indexing_maps = [
    affine_map<(i, j, k) -> (i, k)>, // first matmul arg
    affine_map<(i, j, k) -> (k, j)>, // second matmul arg
    affine_map<(i, j, k) -> (i, j)>, // sampling matrix
    affine_map<(i, j, k) -> (i, j)>  // output
  ],
  // For each index (i, j, k in this instance), define the kind of iteration:
  //   parallel  -> we can iterate over this dimension in parallel
  //   reduction -> we perform a reduction over this dimension
  iterator_types = ["parallel", "parallel", "reduction"],
  // An optional docstring.
  doc = "O(i,j) = C(i, j) * SUM_k A(i, k) * B(k, j)"
}
// Define the function
func.func @sddmm_kernel(
  // 1st argument, dense 2D
  %a: tensor<10000x10000xf64>,
  // 2nd argument, likewise
  %b: tensor<10000x10000xf64>,
  // 3rd argument, 2D CSR
  %c: tensor<10000x10000xf64, #CSR>
) -> tensor<10000x10000xf64, #CSR>
attributes { llvm.emit_c_interface } {
  // Define constants 0 and 1.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Get the dimension sizes of the output.
  %d0 = tensor.dim %a, %c0 : tensor<10000x10000xf64>
  %d1 = tensor.dim %b, %c1 : tensor<10000x10000xf64>
  // sparse_tensor.print %c : tensor<10000x10000xf64, #CSR>
  // Define the scratch output array, in CSR format, that will be invalidated.
  %out = tensor.empty() : tensor<10000x10000xf64, #CSR>
  // Define the kernel.
  // `linalg.generic` is like gufuncs in NumPy.
  %sddmm = linalg.generic #trait_sddmm
    ins(%a, %b, %c : tensor<10000x10000xf64>, tensor<10000x10000xf64>, tensor<10000x10000xf64, #CSR>)
    outs(%out : tensor<10000x10000xf64, #CSR>) {
  // This defines the inner loop.
  ^bb0(%a_one: f64, %b_one: f64, %c_one: f64, %init_one: f64):
    %1 = arith.mulf %a_one, %b_one : f64
    // The `sparse_tensor.binary` below makes %2 "empty"/"non-present" when
    // %c_one is missing: the empty left/right regions mean nothing is emitted
    // unless both operands are present.
    %2 = sparse_tensor.binary %1, %c_one : f64, f64 to f64
      overlap={
        ^bb0(%arg0: f64, %arg1: f64):
          %0 = arith.mulf %arg0, %arg1 : f64
          sparse_tensor.yield %0 : f64
      }
      left={}
      right={}
    %3 = arith.addf %init_one, %2 : f64
    linalg.yield %3 : f64
  } -> tensor<10000x10000xf64, #CSR>
  // sparse_tensor.print %sddmm : tensor<10000x10000xf64, #CSR>
  return %sddmm : tensor<10000x10000xf64, #CSR>
}
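As a sanity reference for the docstring above, the kernel computes the following (a minimal dense NumPy sketch; `sddmm_dense_reference` is a hypothetical helper for small sizes only, not part of the pipeline):

import numpy as np
import scipy.sparse as sps

def sddmm_dense_reference(a: np.ndarray, b: np.ndarray, c: sps.csr_array) -> sps.csr_array:
    # O = C .* (A @ B): materializes the full dense product, so only use this
    # as a semantic reference (explicit zeros in C may be handled differently).
    return sps.csr_array((a @ b) * c.toarray())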
compile.sh
#!/usr/bin/env bash
set -exuo pipefail
SCRIPT_LOCATION=$(dirname "$0")
LLVM_BUILD="/Users/habbasi/Quansight/llvm-project/build"
LLVM_BIN="$LLVM_BUILD/bin"
# Optimize the MLIR and emit certain utilities
"$LLVM_BIN/mlir-opt" "$SCRIPT_LOCATION/sddmm.mlir" \
--sparse-assembler="direct-out=False" \
--sparsifier="enable-runtime-library=false" \
-llvm-request-c-wrappers \
-lower-affine -convert-scf-to-cf -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts \
-o "$SCRIPT_LOCATION/sddmm_opt.mlir"
# Convert MLIR to LLVM IR
"$LLVM_BIN/mlir-translate" "$SCRIPT_LOCATION/sddmm_opt.mlir" \
--mlir-to-llvmir \
-o "$SCRIPT_LOCATION/sddmm_opt.ll"
# Compile the LLVM IR to an optimized object file
"$LLVM_BIN/clang++" "$SCRIPT_LOCATION/sddmm_opt.ll" \
-flto -dynamiclib -fuse-ld=lld \
-L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib \
-O2 -funsafe-math-optimizations -fno-trapping-math \
-c -o "$SCRIPT_LOCATION/a.o"
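# Link the object file against the MLIR runner-utility objects into a dylib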
"$LLVM_BIN/clang++" "$SCRIPT_LOCATION/a.o" \
"$LLVM_BUILD/tools/mlir/lib/ExecutionEngine/CMakeFiles/mlir_runner_utils.dir/RunnerUtils.cpp.o" \
"$LLVM_BUILD/tools/mlir/lib/ExecutionEngine/CMakeFiles/mlir_c_runner_utils.dir/CRunnerUtils.cpp.o" \
"$LLVM_BUILD/tools/mlir/lib/ExecutionEngine/CMakeFiles/mlir_float16_utils.dir/Float16bits.cpp.o" \
-flto -dynamiclib -fuse-ld=lld \
-L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib \
-O2 -funsafe-math-optimizations -fno-trapping-math \
-o "$SCRIPT_LOCATION/a.dylib"
# Read the symbol table
"$LLVM_BIN/llvm-nm" -C "$SCRIPT_LOCATION/a.dylib"
sddmm_test.py
import numpy as np
import ctypes
import typing
import scipy.sparse as sps
import weakref
import time
c_intptr = np.ctypeslib.c_intp
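# MLIR lowers each ranked memref to a descriptor struct of the form
#   { allocated_ptr, aligned_ptr, offset, sizes[ndim], strides[ndim] },
# with strides counted in elements rather than bytes; _get_fields mirrors
# that layout below.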
def _get_fields(*, ndim: int, dtype: np.dtype):
    ptr_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))
    return [
        ("allocated", ptr_type),
        ("aligned", ptr_type),
        ("offset", c_intptr),
        ("sizes", c_intptr * ndim),
        ("strides", c_intptr * ndim),
    ]
PREVENT_GC_C2 = {}
PREVENT_GC_C = weakref.WeakValueDictionary()
class NumpyConversionMixin(ctypes.Structure):
    NDIM: typing.ClassVar[int]
    DTYPE: typing.ClassVar[np.dtype]

    @classmethod
    def from_numpy(cls: typing.Type[ctypes.Structure], arr: np.ndarray) -> typing.Self:
        ptr_type = cls._fields_[0][1]
        data_ptr = arr.ctypes.data_as(ptr_type)
        offset = 0
        sizes = cls._fields_[3][1](*arr.shape)
        strides = cls._fields_[4][1](*map(lambda s: s // arr.itemsize, arr.strides))
        ret = cls(data_ptr, data_ptr, offset, sizes, strides)
        return ret

    def to_numpy(self: ctypes.Structure) -> np.ndarray:
        """CAUTION: Assumes self owns the data."""
        ret = np.ctypeslib.as_array(self.aligned, shape=tuple(self.sizes))
        PREVENT_GC_C[self] = ret
        PREVENT_GC_C2[self] = ret
        return ret

    @classmethod
    def new(cls, *, shape: int | tuple[int, ...]) -> typing.Self:
        return cls.from_numpy(np.empty(shape, dtype=cls.DTYPE))

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, value) -> bool:
        return self is value

    def __del__(self) -> None:
        if PREVENT_GC_C2.pop(self, None) is not None:
            cdll.free(self.allocated)
class Memref2DDouble(NumpyConversionMixin):
    DTYPE = np.dtype(np.float64)
    NDIM = 2
    _fields_ = _get_fields(ndim=NDIM, dtype=DTYPE)

class Memref1DDouble(NumpyConversionMixin):
    DTYPE = np.dtype(np.float64)
    NDIM = 1
    _fields_ = _get_fields(ndim=NDIM, dtype=DTYPE)

class Memref1DIndex(NumpyConversionMixin):
    DTYPE = np.dtype(np.uintp)
    NDIM = 1
    _fields_ = _get_fields(ndim=NDIM, dtype=DTYPE)
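# The sparse-assembler pass expands a CSR tensor into three buffers:
# positions (indptr), coordinates (indices), and values (data). This struct
# mirrors that triple for the function's CSR result.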
class CsrDouble(ctypes.Structure):
    _fields_ = [
        ("indptr", Memref1DIndex),
        ("indices", Memref1DIndex),
        ("data", Memref1DDouble),
    ]

    @classmethod
    def from_scipy_sparse(cls, arr: sps.csr_array | sps.csr_matrix) -> typing.Self:
        # NOTE: the astype() temporaries are not kept alive here; the caller
        # must keep the arrays alive for as long as the kernel uses them.
        indptr = Memref1DIndex.from_numpy(arr.indptr.astype(np.int64))
        indices = Memref1DIndex.from_numpy(arr.indices.astype(np.int64))
        data = Memref1DDouble.from_numpy(arr.data.astype(np.float64))
        return cls(indptr=indptr, indices=indices, data=data)

    def to_scipy_sparse(self, *, shape: tuple[int, int]) -> sps.csr_array:
        indptr = self.indptr.to_numpy()[: shape[0] + 1]
        indices = self.indices.to_numpy()[indptr[0] : indptr[-1]]
        data = self.data.to_numpy()[indptr[0] : indptr[-1]]
        if indptr[0] != 0:
            indptr -= indptr[0]
        return sps.csr_array((data, indices, indptr), shape=shape)

    @classmethod
    def new(cls, *, shape: tuple[int, int], nnz: int) -> typing.Self:
        return cls(
            indptr=Memref1DIndex.new(shape=shape[0] + 1),
            indices=Memref1DIndex.new(shape=nnz),
            data=Memref1DDouble.new(shape=nnz),
        )

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, value: object) -> bool:
        return self is value
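# `llvm.emit_c_interface` makes MLIR emit a `_mlir_ciface_sddmm_kernel` wrapper
# that takes pointers to the descriptor structs; the returned struct is passed
# first, as an out-parameter.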
cdll = ctypes.CDLL("/Users/habbasi/Quansight/llvm-project/scratch/a.dylib")
cdll._mlir_ciface_sddmm_kernel.restype = None
cdll.free.restype = None
def debug_print(name: str, arr: np.ndarray | sps.csr_array) -> None:
    print(f"--- Printing {name} ---")
    if isinstance(arr, np.ndarray):
        print(f"{arr=}")
    elif isinstance(arr, sps.csr_array | sps.csr_matrix):
        print(f"{arr.indptr=}")
        print(f"{arr.indices=}")
        print(f"{arr.data=}")
    print(f"--- End Printing {name} ---")
def sddmm(a: np.ndarray, b: np.ndarray, c: sps.csr_array | sps.csr_matrix) -> sps.csr_array:
    out_shape = c.shape
    nnz = c.nnz
    a_ref = Memref2DDouble.from_numpy(a)
    b_ref = Memref2DDouble.from_numpy(b)
    # Keep the int64 copies alive in locals for the duration of the kernel call.
    c_indptr_tmp = c.indptr.astype(np.int64)
    c_indices_tmp = c.indices.astype(np.int64)
    c_indptr_ref = Memref1DIndex.from_numpy(c_indptr_tmp)
    c_indices_ref = Memref1DIndex.from_numpy(c_indices_tmp)
    c_data_ref = Memref1DDouble.from_numpy(c.data)
    out = CsrDouble()
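    # (As I understand it) direct-out=false: the caller pre-allocates the output
    # buffers and passes them as trailing arguments; direct-out=true: the kernel
    # allocates the output itself and the caller takes ownership of the buffers
    # (freed via `free` in NumpyConversionMixin.__del__).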
    if not DIRECT_OUT:
        out_indptr = np.empty(shape=out_shape[0] + 1, dtype=np.int64)
        out_indices = np.empty(shape=nnz, dtype=np.int64)
        out_data = np.empty(shape=nnz, dtype=np.float64)
        out_indptr_ref = Memref1DIndex.from_numpy(out_indptr)
        out_indices_ref = Memref1DIndex.from_numpy(out_indices)
        out_data_ref = Memref1DDouble.from_numpy(out_data)
        cdll._mlir_ciface_sddmm_kernel(
            ctypes.byref(out),
            ctypes.byref(a_ref),
            ctypes.byref(b_ref),
            ctypes.byref(c_indptr_ref), ctypes.byref(c_indices_ref), ctypes.byref(c_data_ref),
            ctypes.byref(out_indptr_ref), ctypes.byref(out_indices_ref), ctypes.byref(out_data_ref),
        )
        return sps.csr_array((out_data, out_indices, out_indptr), shape=out_shape)
    else:
        cdll._mlir_ciface_sddmm_kernel(
            ctypes.byref(out),
            ctypes.byref(a_ref),
            ctypes.byref(b_ref),
            ctypes.byref(c_indptr_ref), ctypes.byref(c_indices_ref), ctypes.byref(c_data_ref),
        )
        ret = out.to_scipy_sparse(shape=out_shape)
        return ret
if __name__ == "__main__":
    LEN = 10000
    DENSITY = 0.0001
    DIRECT_OUT = False

    arr1 = np.random.random((LEN, LEN))
    arr2 = np.random.random((LEN, LEN))
    arr3 = sps.random_array((LEN, LEN), format="csr", density=DENSITY)

    start = time.time_ns()
    actual = sddmm(arr1, arr2, arr3)
    print(f"Fused SDDMM took {(time.time_ns() - start) / 10 ** 9} s.")

    start = time.time_ns()
    expected = ((arr1 @ arr2) * arr3).asformat("csr")
    print(f"Non-fused SDDMM took {(time.time_ns() - start) / 10 ** 9} s.")

    np.testing.assert_equal(actual.indptr, expected.indptr)
    np.testing.assert_equal(actual.indices, expected.indices)
    np.testing.assert_allclose(actual.data, expected.data)
    print("Success!")
A few final questions/requests:
- Dynamic shapes on CSR don't seem to work, even with direct-out=true. Is that intended? Am I perhaps compiling it wrong? Can I help add support, or request that it be added?
- The NumPy version uses parallelism to speed things up, which makes it hard to beat. However, in this particular instance there is a simple way to speed up the dot product inside the SDDMM. Is emitting that kind of parallelism supported?
- The sparsity structure can simply be copied in this instance, with no need for sparse insertions. Could that kind of code be emitted? (See the sketch below for what I mean.)
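For that last point, here is a hedged Python sketch of the hand-written version I have in mind (`sddmm_copy_structure` is a hypothetical helper; the gather allocates O(nnz × k) memory, so it is only illustrative):

import numpy as np
import scipy.sparse as sps

def sddmm_copy_structure(a: np.ndarray, b: np.ndarray, c: sps.csr_array) -> sps.csr_array:
    # indptr/indices are copied verbatim from C; only the values are computed.
    rows = np.repeat(np.arange(c.shape[0]), np.diff(c.indptr))
    data = np.einsum("ij,ji->i", a[rows], b[:, c.indices]) * c.data
    return sps.csr_array((data, c.indices.copy(), c.indptr.copy()), shape=c.shape)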