Compiling into a shared object and calling MLIR from external code

Hello, I have the following MLIR code, and I’d like to compile it into a working shared object via the CLI tools provided by LLVM.

// Define the CSR format
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0: dense, d1: compressed)
}>

// Define the 2D Dense format
#Dense2D = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0: dense, d1: dense)
}>

// Define the iterator trait for SDDMM
#trait_sddmm = {
    // Defines the equivalent of "index notation" for a GUFunc
    // For each input/output, define the input indices and output indices
    indexing_maps = [
        affine_map<(i, j, k) -> (i, k)>,    // first matmul arg
        affine_map<(i, j, k) -> (k, j)>,    // second matmul arg
        affine_map<(i, j, k) -> (i, j)>,    // sampling matrix
        affine_map<(i, j, k) -> (i, j)>     // output
    ],
    // For each index (i, j, k in this instance), define the kind of iteration
    // parallel -> We can iterate over this in parallel
    // reduction -> We perform a reduction over this dimension.
    iterator_types = ["parallel", "parallel", "reduction"],
    // An optional docstring.
    doc = "O(i,j) = C(i, j) * SUM_k A(i, k) * B(k, j)"
}

// Define the function
func.func @sddmm_kernel(
    // 1st argument, dense 2D
    %a: tensor<?x?xf64, #Dense2D>,
    // 2nd argument, likewise
    %b: tensor<?x?xf64, #Dense2D>,
    // Third argument; 2D CSR
    %c: tensor<?x?xf64, #CSR>
) -> tensor<?x?xf64, #CSR> {
    // Define constants 0 and 1
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Get the dimension sizes of the output
    %d0 = tensor.dim %a, %c0 : tensor<?x?xf64, #Dense2D>
    %d1 = tensor.dim %b, %c1 : tensor<?x?xf64, #Dense2D>
    // Define the scratch output array that will be invalidated in CSR format
    %out = tensor.empty(%d0, %d1) : tensor<?x?xf64, #CSR>
    // Define the kernel
    // `linalg.generic` is like gufuncs in NumPy
    %sddmm = linalg.generic #trait_sddmm
    ins(%a, %b, %c : tensor<?x?xf64, #Dense2D>, tensor<?x?xf64, #Dense2D>, tensor<?x?xf64, #CSR>)
    outs(%out : tensor<?x?xf64, #CSR>) {
    ^bb0(%a_one: f64, %b_one: f64, %c_one: f64, %init_one: f64):
        // This defines the inner loop
        %1 = arith.mulf %c_one, %a_one : f64
        %2 = arith.mulf %1, %b_one : f64
        %3 = arith.addf %init_one, %2 : f64
        linalg.yield %3 : f64
    } -> tensor<?x?xf64, #CSR>

    return %sddmm : tensor<?x?xf64, #CSR>
}

I’m currently at the following stage: if I run the following command:

clear; time ./bin/mlir-opt '/path/to/sddmm.mlir' --sparsifier="enable-arm-sve=true" -llvm-request-c-wrappers -convert-func-to-llvm | ./bin/mlir-translate --mlir-to-llvmir | ./bin/clang -O2 -c -x ir - -o - | ./bin/llvm-nm -C -

I get the following output:

                 U __mlir_ciface_expInsertF64
                 U __mlir_ciface_newSparseTensor
                 U __mlir_ciface_sparseCoordinates0
                 U __mlir_ciface_sparsePositions0
                 U __mlir_ciface_sparseValuesF64
                 U _bzero
                 U _endLexInsert
                 U _free
                 U _malloc
0000000000000000 T _sddmm_kernel
                 U _sparseLvlSize
00000000000003c0 s lCPI0_0
00000000000003d0 s lCPI0_1
00000000000003e0 s lCPI0_2
0000000000000000 t ltmp0
00000000000003c0 s ltmp1
00000000000003f0 s ltmp2

From this I can see I got a couple of things right; most importantly, the _sddmm_kernel symbol is indeed being exported. I’m still missing a couple of important pieces before I can link and call this, though:

  • First, this expects references to upstream, MLIR-specific symbols. How do I link those in to produce a shared object?
  • Second, I need to know the ABI when calling this. Is there a way to emit something like a C struct for tensor<?x?xf64, #Dense2D> and tensor<?x?xf64, #CSR>?

These are provided by the following library: llvm-project/mlir/lib/ExecutionEngine/CMakeLists.txt at 14774ad59da4aaa8eb3da21c8c53b59834a1aeb7 · llvm/llvm-project · GitHub. You may also need to link in the sparse runtime: llvm-project/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt at main · llvm/llvm-project · GitHub

@aartbik for this.

Okay, I amended the command line to:

clear; time ./bin/mlir-opt '/Users/habbasi/Quansight/llvm-project/sddmm.mlir' --sparsifier="enable-arm-sve=true" -llvm-request-c-wrappers -convert-func-to-llvm | ./bin/mlir-translate --mlir-to-llvmir | ./bin/clang -O2 -c -x ir - -o a.o | ld a.o -o a.so -L./lib -lMLIRSparseTensorRuntime -L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem | ./bin/llvm-nm -C a.so

Relevant Stderr Output:

Undefined symbols for architecture arm64:
  "__mlir_ciface_expInsertF64", referenced from:
      _sddmm_kernel in a.o
  "__mlir_ciface_newSparseTensor", referenced from:
      _sddmm_kernel in a.o
  "__mlir_ciface_sparseCoordinates0", referenced from:
      _sddmm_kernel in a.o
  "__mlir_ciface_sparsePositions0", referenced from:
      _sddmm_kernel in a.o
  "__mlir_ciface_sparseValuesF64", referenced from:
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
  "_endLexInsert", referenced from:
      _sddmm_kernel in a.o
  "_main", referenced from:
      <initial-undefines>
  "_sparseLvlSize", referenced from:
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
      _sddmm_kernel in a.o
ld: symbol(s) not found for architecture arm64

And it still has some missing symbols; adding CRunnerUtils.cpp.o only increases the number of missing symbols and doesn’t get rid of any of them.

Edit: I linked in libMLIRSparse* but it didn’t seem to help; -lSystem is the only thing that got rid of a few needed symbols.

There is no direct ABI between the external world and MLIR sparse tensors (since there are currently two bespoke ways of storing MLIR sparse tensors). However, the ABI is well-defined through the constituent buffers, followed by an assemble operation (and a disassemble to return the sparse tensor to the outside world).

For example, passing in CSR for a 1000x1000 sparse matrix with 100 nonzeros, stored as coordinate and value arrays of length 100 plus a positions array of length 1001, can be done as follows.

 %csr = sparse_tensor.assemble (%csr_pos, %csr_index), %csr_data : 
         (tensor<1001xindex>, tensor<100xindex>), tensor<100xf64> to tensor<1000x1000xf64, #CSR>

The tensors eventually lower into memref buffers. Passing in these constituent buffers follows the dense ABI you are already familiar with.
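
For reference, a ranked memref of rank N lowers to a descriptor holding an allocated pointer, an aligned pointer, an offset, and per-dimension sizes and strides (offset, sizes, and strides are counted in elements, not bytes). The _mlir_ciface_ wrappers generated by llvm.emit_c_interface / -llvm-request-c-wrappers take pointers to such descriptors. A minimal C sketch of that layout (the struct and field names are purely illustrative, not something MLIR emits):

#include <stdint.h>

// Descriptor for a 1-D f64 memref: allocation base pointer, aligned
// pointer used for access, offset from the aligned pointer, then the
// per-dimension sizes and strides (all counted in elements).
typedef struct {
  double  *allocated;
  double  *aligned;
  intptr_t offset;
  intptr_t sizes[1];
  intptr_t strides[1];
} MemRef1DF64;

// Same layout for rank 2; sizes/strides simply grow with the rank.
typedef struct {
  double  *allocated;
  double  *aligned;
  intptr_t offset;
  intptr_t sizes[2];
  intptr_t strides[2];
} MemRef2DF64;

An index-typed memref (e.g. the CSR positions/coordinates buffers with the default posWidth/crdWidth) uses the same layout with intptr_t elements instead of double.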

As far as the sparse runtime library support methods are concerned, pulling in mlir_c_runner_utils should suffice. Note that you can toggle between --sparsifier="enable-runtime-library=false" and --sparsifier="enable-runtime-library=true" to generate code that either uses direct codegen or relies on the support library (the former creates fewer dependencies on a runtime, except for some file I/O stuff).

Thank you so much for all the help so far; it’s helped me get a lot further along. I’ve added the following “wrapper” so I can call the kernel via an ABI:

func.func @sddmm(
    // 1st argument, dense 2D
    %a_ten: tensor<?x?xf64>,
    // 2nd argument, likewise
    %b_ten: tensor<?x?xf64>,
    // Third argument; 2D CSR
    %c_pos: tensor<?xindex>,
    %c_index: tensor<?xindex>,
    %c_vals: tensor<?xf64>
) -> (tensor<?xindex>, tensor<?xindex>, tensor<?xf64>)
attributes { llvm.emit_c_interface } {
    %a = sparse_tensor.convert %a_ten : tensor<?x?xf64> to tensor<?x?xf64, #Dense2D>
    %b = sparse_tensor.convert %b_ten : tensor<?x?xf64> to tensor<?x?xf64, #Dense2D>
    // Define constants 0 and 1
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %d0 = tensor.dim %a, %c0 : tensor<?x?xf64, #Dense2D>
    %d1 = tensor.dim %b, %c1 : tensor<?x?xf64, #Dense2D>
    %c = sparse_tensor.assemble (%c_pos, %c_index), %c_vals : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<?x?xf64, #CSR>

    %out = func.call @sddmm_kernel(%a, %b, %c) : (tensor<?x?xf64, #Dense2D>, tensor<?x?xf64, #Dense2D>, tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>

    %out_pos, %out_index, %out_vals, %out_pos_len, %out_index_len, %out_vals_len =
        sparse_tensor.disassemble %out : tensor<?x?xf64, #CSR>
        out_lvls(%out_pos, %out_index : tensor<?xindex>, tensor<?xindex>)
        out_vals(%out_vals : tensor<?xf64>) ->
            (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>, (index, index), index

    return %out_pos, %out_index, %out_vals : tensor<?xindex>, tensor<?xindex>, tensor<?xf64>
}

I’ve run up against the following error, which I have no clue how to solve, and the error message isn’t very helpful:

/Users/habbasi/Quansight/llvm-project/sddmm.mlir:81:10: error: the sparse-tensor must have static shape
    %c = sparse_tensor.assemble (%c_pos, %c_index), %c_vals : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<?x?xf64, #CSR>
         ^
/Users/habbasi/Quansight/llvm-project/sddmm.mlir:81:10: note: see current operation: %6 = "sparse_tensor.assemble"(%arg2, %arg3, %arg4) : (tensor<?xindex>, tensor<?xindex>, tensor<?xf64>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>

I understand that the shape cannot be inferred from just the levels; however, I do not see a way to specify the shape of a dynamic tensor in sparse_tensor.assemble in the docs.

Also, one additional question: you mentioned the tensor/memref ABI. I’d appreciate any resources on that, as this is my first foray into anything LLVM/MLIR.

I forgot to say that we also support the --sparse-assembler pass, which automatically converts all public methods of a module that have sparse tensors as inputs/outputs into the kind of wrapper you hand-wrote above. Currently, this flag follows the JAX conventions, while the --sparse-assembler="direct-out=true" variant supports the PyTorch output conventions of TORCH-MLIR.

Due to the JAX conventions, the constituent buffers must have static shape, and any change in size (essentially a change in the number of nonzeros) will trigger a JIT compilation in that environment. So even though the MLIR Sparsifier works just as well with dynamic sizes as with static shapes, at the moment this particular assemble operation requires a static shape. So give that a try. If that works well for you, we can talk about relaxing the constraints on the ops.

Thanks a lot! I was able to get an initial version “working” (sort of), but it only “works” some of the time. My best guess is that it’s related to alignment; can you confirm that? I’ve gone all-out, so while there might be memory leaks in the code, I highly doubt the presence of use-after-free or double-free issues.

Symptoms

  • It sometimes works, seemingly at random, about a quarter of the time, and returns correct results.
  • It sometimes dies.
  • It sometimes returns “bad results”, usually a sparse array that can be interpreted as having zero elements, or fewer than there should be.
  • When it “works”, the performance is similar to a non-fused op (matmul plus elementwise multiplication with CSR). This might just be me messing up the MLIR code itself, though.

Code

Unfortunately ZIPs aren’t allowed, so I’m posting the code separately in collapsible sections.

sddmm.mlir
// Define the CSR format
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0: dense, d1: compressed)
}>

// Define the 2D Dense format
#Dense2D = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0: dense, d1: dense)
}>

// Define the iterator trait for SDDMM
#trait_sddmm = {
    // Defines the equivalent of "index notation" for a GUFunc
    // For each input/output, define the input indices and output indices
    indexing_maps = [
        affine_map<(i, j, k) -> (i, k)>,    // first matmul arg
        affine_map<(i, j, k) -> (k, j)>,    // second matmul arg
        affine_map<(i, j, k) -> (i, j)>,    // sampling matrix
        affine_map<(i, j, k) -> (i, j)>     // output
    ],
    // For each index (i, j, k in this instance), define the kind of iteration
    // parallel -> We can iterate over this in parallel
    // reduction -> We perform a reduction over this dimension.
    iterator_types = ["parallel", "parallel", "reduction"],
    // An optional docstring.
    doc = "O(i,j) = C(i, j) * SUM_k A(i, k) * B(k, j)"
}

// Define the function
func.func @sddmm_kernel(
    // 1st argument, dense 2D
    %a: tensor<10000x10000xf64, #Dense2D>,
    // 2nd argument, likewise
    %b: tensor<10000x10000xf64, #Dense2D>,
    // Third argument; 2D CSR
    %c: tensor<10000x10000xf64, #CSR>
) -> tensor<10000x10000xf64, #CSR>
attributes { llvm.emit_c_interface } {
    // Define constants 0 and 1
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Get the dimension sizes of the output
    %d0 = tensor.dim %a, %c0 : tensor<10000x10000xf64, #Dense2D>
    %d1 = tensor.dim %b, %c1 : tensor<10000x10000xf64, #Dense2D>
    // Define the scratch output array that will be invalidated in CSR format
    %out = tensor.empty() : tensor<10000x10000xf64, #CSR>
    // Define the kernel
    // `linalg.generic` is like gufuncs in NumPy
    %sddmm = linalg.generic #trait_sddmm
    ins(%a, %b, %c : tensor<10000x10000xf64, #Dense2D>, tensor<10000x10000xf64, #Dense2D>, tensor<10000x10000xf64, #CSR>)
    outs(%out : tensor<10000x10000xf64, #CSR>) {
    // This defines the inner loop
    ^bb0(%a_one: f64, %b_one: f64, %c_one: f64, %init_one: f64):
        // This part makes %1 "empty"/"non-present" when %c_one is missing.
        %1 = sparse_tensor.binary %a_one, %c_one : f64, f64 to f64
        overlap={
        ^bb0(%arg0: f64, %arg1: f64):
            %0 = arith.mulf %arg0, %arg1 : f64
            sparse_tensor.yield %0 : f64
        }
        left={}
        right={}
        %2 = arith.mulf %1, %b_one : f64
        %3 = arith.addf %init_one, %2 : f64
        linalg.yield %3 : f64
    } -> tensor<10000x10000xf64, #CSR>

    return %sddmm : tensor<10000x10000xf64, #CSR>
}
compile.sh (meant for ARM Macs)
#!/usr/bin/env bash

set -euo pipefail

SCRIPT_LOCATION=$(dirname "$0")
LLVM_BIN="/Users/habbasi/Quansight/llvm-project/build/bin"

# Optimize the MLIR and emit certain utilities
"$LLVM_BIN/mlir-opt" "$SCRIPT_LOCATION/sddmm.mlir" \
    --sparse-assembler="direct-out=true" \
    --sparsifier="enable-runtime-library=false" \
    -llvm-request-c-wrappers \
    -lower-affine  -convert-scf-to-cf  -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts \
    -o "$SCRIPT_LOCATION/sddmm_opt.mlir"

# Convert MLIR to LLVM IR
"$LLVM_BIN/mlir-translate" "$SCRIPT_LOCATION/sddmm_opt.mlir" \
    --mlir-to-llvmir \
    -o "$SCRIPT_LOCATION/sddmm_opt.ll"

# Compile the LLVM IR into an optimized shared library
"$LLVM_BIN/clang" "$SCRIPT_LOCATION/sddmm_opt.ll" \
    -flto -dynamiclib -fuse-ld=lld \
    -L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib \
    -O2 \
    -o "$SCRIPT_LOCATION/a.dylib"

# Read the symbol table
"$LLVM_BIN/llvm-nm" -C "$SCRIPT_LOCATION/a.dylib"
sddmm_test.py
import numpy as np
import ctypes
import typing
import scipy.sparse as sps
import weakref

c_intptr = np.ctypeslib.c_intp

class NumpyConversionMixin:
    @classmethod
    def from_numpy(cls: typing.Type[ctypes.Structure], arr: np.ndarray) -> typing.Self:
        ptr_type = cls._fields_[0][1]
        data_ptr = arr.ctypes.data_as(ptr_type)
        offset = 0
        sizes = cls._fields_[3][1](*arr.shape)
        strides = cls._fields_[4][1](*map(lambda s: s // arr.itemsize, arr.strides))
        ret = cls(data_ptr, data_ptr, offset, sizes, strides)

        CACHE[ret] = arr

        return ret
    
    def to_numpy(self: ctypes.Structure) -> np.ndarray:
        ret = np.ctypeslib.as_array(self.aligned, self.sizes)
        print(f"{self.offset=}, {tuple(self.sizes)=}, {tuple(self.strides)=}, {ret.view(np.intp)[0]=}, {ret.view(np.float64)[0]=}")

        CACHE[self] = ret

        return ret

    
    @classmethod
    def new(cls) -> typing.Self:
        ptr_type = cls._fields_[0][1]
        data_ptr = ptr_type()
        offset = 0
        size = cls._fields_[3][1](0)
        strides = cls._fields_[4][1](0)
        return cls(data_ptr, data_ptr, offset, size, strides)

class Memref1DDouble(ctypes.Structure, NumpyConversionMixin):
    _fields_ = [
        ("allocated", ctypes.POINTER(ctypes.c_double)),
        ("aligned", ctypes.POINTER(ctypes.c_double)),
        ("offset", c_intptr),
        ("sizes", c_intptr * 1),
        ("strides", c_intptr * 1),
    ]

    def __hash__(self) -> int:
        return hash(id(self))
        

class Memref1DIndex(ctypes.Structure, NumpyConversionMixin):
    _fields_ = [
        ("allocated", ctypes.POINTER(c_intptr)),
        ("aligned", ctypes.POINTER(c_intptr)),
        ("offset", c_intptr),
        ("sizes", c_intptr * 1),
        ("strides", c_intptr * 1),
    ]

    def __hash__(self) -> int:
        return hash(id(self))

CACHE = weakref.WeakKeyDictionary()

class CsrDouble(ctypes.Structure):
    _fields_ = [
        ("indptr", Memref1DIndex),
        ("indices", Memref1DIndex),
        ("data", Memref1DDouble),
    ]

    @classmethod
    def from_scipy_sparse(cls, arr: sps.csr_array | sps.csr_matrix) -> typing.Self:
        indptr = Memref1DIndex.from_numpy(arr.indptr.astype(np.int64))
        indices = Memref1DIndex.from_numpy(arr.indices.astype(np.int64))
        data = Memref1DDouble.from_numpy(arr.data.astype(np.float64))

        return cls(indptr=indptr, indices=indices, data=data)
    
    def to_scipy_sparse(self, shape=None) -> sps.csr_array:
        indptr = self.indptr.to_numpy()[:shape[0]+1]
        indices = self.indices.to_numpy()[indptr[0]:indptr[-1]]
        data = self.data.to_numpy()[indptr[0]:indptr[-1]]
        if indptr[0] != 0:
            indptr -= indptr[0]

        return sps.csr_array((data, indices, indptr), shape=shape)
    
    @classmethod
    def new(cls) -> typing.Self:
        return cls(indices=Memref1DIndex.new(), indptr=Memref1DIndex.new(), data=Memref1DDouble.new())


    def __hash__(self) -> int:
        return hash(id(self))


cdll = ctypes.CDLL("/Users/habbasi/Quansight/llvm-project/scratch/a.dylib")

def sddmm(a: np.ndarray, b: np.ndarray, c: sps.csr_array | sps.csr_matrix) -> sps.csr_array:
    out_shape = c.shape
    a = Memref1DDouble.from_numpy(a.flatten())
    b = Memref1DDouble.from_numpy(b.flatten())
    c = CsrDouble.from_scipy_sparse(c)
    out = CsrDouble.new()

    cdll._mlir_ciface_sddmm_kernel(
        ctypes.pointer(out),
        ctypes.pointer(a),
        ctypes.pointer(b),
        *[ctypes.pointer(getattr(c, f[0])) for f in CsrDouble._fields_],
    )

    return out.to_scipy_sparse(shape=out_shape)

if __name__ == "__main__":
    LEN = 10000
    arr1 = np.random.random((LEN, LEN))
    arr2 = np.random.random((LEN, LEN))
    arr3 = sps.random_array((LEN, LEN), format="csr", density=0.001)
    print(f"{len(arr3.indptr)=}, {len(arr3.indices)=}, {len(arr3.data)=}, {arr3.nnz=}, {arr3.has_canonical_format=}")

    ret = sddmm(arr1, arr2, arr3)
    print(str(ret))

@aartbik Upon further investigation, this definitely looks like a case of miscompilation for large shapes. Fixing density=0.001 and comparing results against a naïve implementation, I was able to determine on both ARM64 and AMD64 macOS (same script) that LEN = 110 works while LEN = 120 breaks (with a segfault on ARM and an illegal instruction on x86). Going even higher reproduces the segfault on x86 as well, rather than the illegal instruction. Of course, I recompiled the MLIR each time with the correct shapes and edited the Python file manually.

The sddmm_opt.ll intermediate file is identical on both systems, so it’s probably an issue before that stage. The illegal instruction is probably an x86-specific issue. Unfortunately I lack the skills to say exactly where at this point.

One other thing that I noticed was that --sparse-assembler generated a signature with 9 memrefs, whereas by my count it should be 8 (2 input dense + 3 for the input CSR + 3 for the output CSR). --sparse-assembler="direct-out=true", on the other hand, did the correct thing (5 input pointers to memrefs plus 1 pointer to a struct of 3 memrefs).

I’m on LLVM commit 3c721b90d363bf73b78467f6e86c879235bac1b2 on both machines.

A couple of hints.

(1) There is really no need to use the “all-dense” annotation for dense tensors. Using it implies that the MLIR Sparsifier takes all its “sparse” paths to generate code for these tensors, even though this eventually just results in linearized storage (with some overhead for the [dis]assemble operations in this case). Unless you have a very specific reason for it, I suggest using dense tensors for dense tensors :wink: In other words, drop the encoding below and simply use tensor<10000x10000xf64>:

#Dense2D = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0: dense, d1: dense)
}>
...
tensor<10000x10000xf64, #Dense2D>

(2) Since you also return a sparse tensor (viz. sparse output value), there is a strong difference between --sparse-assembler with JAX semantics of passing in pre-allocated output buffers (yielding 2 + 3 + 3 = 8 input and 3 output) and --sparse-assembler="direct-out=true" with PyTorch semantics (yielding 2 + 3 = 5 input and 3 output, using dynamic shapes).
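
To make the difference concrete, here is a rough C sketch of the wrapper signature a caller would see in each mode, using the illustrative descriptor structs from earlier in the thread. Treat this as an assumption to verify against the generated LLVM IR rather than as a spec; in particular, the _jax/_direct suffixes are not real symbols, since the exported wrapper is _mlir_ciface_sddmm_kernel in both cases.

#include <stdint.h>

// Illustrative memref descriptors (see the earlier sketch); an "index"
// memref carries intptr_t elements.
typedef struct { intptr_t *allocated, *aligned; intptr_t offset, sizes[1], strides[1]; } MemRef1DIndex;
typedef struct { double   *allocated, *aligned; intptr_t offset, sizes[1], strides[1]; } MemRef1DF64;
typedef struct { double   *allocated, *aligned; intptr_t offset, sizes[2], strides[2]; } MemRef2DF64;

// The CSR result is a bundle of three descriptors: positions, coordinates, values.
typedef struct { MemRef1DIndex pos; MemRef1DIndex crd; MemRef1DF64 val; } CsrF64;

// --sparse-assembler (JAX-style): 2 dense inputs + 3 CSR input buffers
// + 3 caller-allocated CSR output buffers; the leading pointer is the
// slot through which the C-interface wrapper returns its results.
void _mlir_ciface_sddmm_kernel_jax(
    CsrF64 *result,
    MemRef2DF64 *a, MemRef2DF64 *b,
    MemRef1DIndex *c_pos, MemRef1DIndex *c_crd, MemRef1DF64 *c_val,
    MemRef1DIndex *out_pos, MemRef1DIndex *out_crd, MemRef1DF64 *out_val);

// --sparse-assembler="direct-out=true" (PyTorch-style): only the inputs
// are passed; the kernel allocates the three output buffers itself and
// hands them back through the leading result pointer.
void _mlir_ciface_sddmm_kernel_direct(
    CsrF64 *result,
    MemRef2DF64 *a, MemRef2DF64 *b,
    MemRef1DIndex *c_pos, MemRef1DIndex *c_crd, MemRef1DF64 *c_val);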

(3) Other than that, use sparse_tensor.print to verify your tensor after assembly. For example, this snippet

#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  posWidth = 32,
  crdWidth = 32
}>

%csr_data = arith.constant dense<
    [ 1.0, 2.0, 3.0 ]
> : tensor<3xf64>

%csr_pos32 = arith.constant dense<
    [0, 1, 3]
> : tensor<3xi32>

%csr_index32 = arith.constant dense<
    [1, 0, 1]
> : tensor<3xi32>

%csr = sparse_tensor.assemble (%csr_pos32, %csr_index32), %csr_data
    : (tensor<3xi32>, tensor<3xi32>), tensor<3xf64> to tensor<2x2xf64, #CSR>

sparse_tensor.print %csr : tensor<2x2xf64, #CSR>

yields:

---- Sparse Tensor ----
nse = 3
dim = ( 2, 2 )
lvl = ( 2, 2 )
pos[1] : ( 0, 1, 3,  )
crd[1] : ( 1, 0, 1,  )
values : ( 1, 2, 3,  )
----

I hope this helps!

Thanks! I was indeed misunderstanding the situation with the JAX calling convention. I’ve made the suggested changes to the MLIR/calling code by removing the format definition entirely for dense tensors. If I change the shapes to dynamic on all tensors and invoke mlir-opt with --sparse-assembler="direct-out=true" (which was alluded to above as working with dynamic shapes) or --sparse-assembler, I get the following error:

./sddmm.mlir:25:1: error: the sparse-tensor must have static shape
func.func @sddmm_kernel(
^
./sddmm.mlir:25:1: note: see current operation: %0 = "sparse_tensor.assemble"(%arg2, %arg3, %arg4) : (tensor<?xindex>, tensor<?xindex>, tensor<?xf64>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>

However, everything compiles again if I switch to static shapes, and from performance measurements it’s fusing correctly. It’s now also returning correct results; the earlier breakage was probably a weird issue with my memory management.

For anyone interested, here is my code:

sddmm.mlir
// Define the CSR format
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0: dense, d1: compressed)
}>

// Define the iterator trait for SDDMM
#trait_sddmm = {
    // Defines the equivalent of "index notation" for a GUFunc
    // For each input/output, define the input indices and output indices
    indexing_maps = [
        affine_map<(i, j, k) -> (i, k)>,    // first matmul arg
        affine_map<(i, j, k) -> (k, j)>,    // second matmul arg
        affine_map<(i, j, k) -> (i, j)>,    // sampling matrix
        affine_map<(i, j, k) -> (i, j)>     // output
    ],
    // For each index (i, j, k in this instance), define the kind of iteration
    // parallel -> We can iterate over this in parallel
    // reduction -> We perform a reduction over this dimension.
    iterator_types = ["parallel", "parallel", "reduction"],
    // An optional docstring.
    doc = "O(i,j) = C(i, j) * SUM_k A(i, k) * B(k, j)"
}

// Define the function
func.func @sddmm_kernel(
    // 1st argument, dense 2D
    %a: tensor<10000x10000xf64>,
    // 2nd argument, likewise
    %b: tensor<10000x10000xf64>,
    // Third argument; 2D CSR
    %c: tensor<10000x10000xf64, #CSR>
) -> tensor<10000x10000xf64, #CSR>
attributes { llvm.emit_c_interface } {
    // Define constants 0 and 1
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Get the dimension sizes of the output
    %d0 = tensor.dim %a, %c0 : tensor<10000x10000xf64>
    %d1 = tensor.dim %b, %c1 : tensor<10000x10000xf64>
    // sparse_tensor.print %c : tensor<10000x10000xf64, #CSR>
    // Define the scratch output array that will be invalidated in CSR format
    %out = tensor.empty() : tensor<10000x10000xf64, #CSR>
    // Define the kernel
    // `linalg.generic` is like gufuncs in NumPy
    %sddmm = linalg.generic #trait_sddmm
    ins(%a, %b, %c : tensor<10000x10000xf64>, tensor<10000x10000xf64>, tensor<10000x10000xf64, #CSR>)
    outs(%out : tensor<10000x10000xf64, #CSR>) {
    // This defines the inner loop
    ^bb0(%a_one: f64, %b_one: f64, %c_one: f64, %init_one: f64):
        // This part makes %2 "empty"/"non-present" when %c_one is missing.
        %1 = arith.mulf %a_one, %b_one : f64
        %2 = sparse_tensor.binary %1, %c_one : f64, f64 to f64
        overlap={
        ^bb0(%arg0: f64, %arg1: f64):
            %0 = arith.mulf %arg0, %arg1 : f64
            sparse_tensor.yield %0 : f64
        }
        left={}
        right={}
        %3 = arith.addf %init_one, %2 : f64
        linalg.yield %3 : f64
    } -> tensor<10000x10000xf64, #CSR>

    // sparse_tensor.print %sddmm : tensor<10000x10000xf64, #CSR>

    return %sddmm : tensor<10000x10000xf64, #CSR>
}
compile.sh
#!/usr/bin/env bash

set -exuo pipefail

SCRIPT_LOCATION=$(dirname "$0")
LLVM_BUILD="/Users/habbasi/Quansight/llvm-project/build"
LLVM_BIN="$LLVM_BUILD/bin"

# Optimize the MLIR and emit certain utilities
"$LLVM_BIN/mlir-opt" "$SCRIPT_LOCATION/sddmm.mlir" \
    --sparse-assembler="direct-out=False" \
    --sparsifier="enable-runtime-library=false" \
    -llvm-request-c-wrappers \
    -lower-affine  -convert-scf-to-cf  -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts \
    -o "$SCRIPT_LOCATION/sddmm_opt.mlir"

# Convert MLIR to LLVM IR
"$LLVM_BIN/mlir-translate" "$SCRIPT_LOCATION/sddmm_opt.mlir" \
    --mlir-to-llvmir \
    -o "$SCRIPT_LOCATION/sddmm_opt.ll"

# Compile the LLVM IR into an optimized (LTO bitcode) object file
"$LLVM_BIN/clang++" "$SCRIPT_LOCATION/sddmm_opt.ll" \
    -flto -dynamiclib -fuse-ld=lld \
    -L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib \
    -O2 -funsafe-math-optimizations -fno-trapping-math \
    -c -o "$SCRIPT_LOCATION/a.o"

# Link the object file together with the MLIR runner-utils objects into the final shared library
"$LLVM_BIN/clang++" "$SCRIPT_LOCATION/a.o" \
    "$LLVM_BUILD/tools/mlir/lib/ExecutionEngine/CMakeFiles/mlir_runner_utils.dir/RunnerUtils.cpp.o" \
    "$LLVM_BUILD/tools/mlir/lib/ExecutionEngine/CMakeFiles/mlir_c_runner_utils.dir/CRunnerUtils.cpp.o" \
    "$LLVM_BUILD/tools/mlir/lib/ExecutionEngine/CMakeFiles/mlir_float16_utils.dir/Float16bits.cpp.o" \
    -flto -dynamiclib -fuse-ld=lld \
    -L/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib \
    -O2 -funsafe-math-optimizations -fno-trapping-math \
    -o "$SCRIPT_LOCATION/a.dylib"

# Read the symbol table
"$LLVM_BIN/llvm-nm" -C "$SCRIPT_LOCATION/a.dylib"
sddmm_test.py
import numpy as np
import ctypes
import typing
import scipy.sparse as sps
import weakref
import time

c_intptr = np.ctypeslib.c_intp

def _get_fields(*, ndim: int, dtype: np.dtype):
    ptr_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))
    return [
        ("allocated", ptr_type),
        ("aligned", ptr_type),
        ("offset", c_intptr),
        ("sizes", c_intptr * ndim),
        ("strides", c_intptr * ndim),
    ]

PREVENT_GC_C2 = {}
PREVENT_GC_C = weakref.WeakValueDictionary()


class NumpyConversionMixin(ctypes.Structure):
    NDIM: typing.ClassVar[int]
    DTYPE: typing.ClassVar[np.dtype]

    @classmethod
    def from_numpy(cls: typing.Type[ctypes.Structure], arr: np.ndarray) -> typing.Self:
        ptr_type = cls._fields_[0][1]
        data_ptr = arr.ctypes.data_as(ptr_type)
        offset = 0
        sizes = cls._fields_[3][1](*arr.shape)
        strides = cls._fields_[4][1](*map(lambda s: s // arr.itemsize, arr.strides))
        ret = cls(data_ptr, data_ptr, offset, sizes, strides)
        return ret
    
    def to_numpy(self: ctypes.Structure) -> np.ndarray:
        """ CAUTION: Assumes self owns the data. """
        ret = np.ctypeslib.as_array(self.aligned, shape=tuple(self.sizes))
        PREVENT_GC_C[self] = ret
        PREVENT_GC_C2[self] = ret
        return ret
    
    @classmethod
    def new(cls, *, shape: int | tuple[int, ...]) -> typing.Self:
        return cls.from_numpy(np.empty(shape, dtype=cls.DTYPE))

    def __hash__(self) -> int:
        return hash(id(self))
    
    def __eq__(self, value) -> bool:
        return self is value
    
    def __del__(self) -> None:
        if PREVENT_GC_C2.pop(self, None) is not None:
            cdll.free(self.allocated)

class Memref2DDouble(NumpyConversionMixin):
    DTYPE = np.dtype(np.float64)
    NDIM = 2

    _fields_ = _get_fields(ndim=NDIM, dtype=DTYPE)

class Memref1DDouble(NumpyConversionMixin):
    DTYPE = np.dtype(np.float64)
    NDIM = 1

    _fields_ = _get_fields(ndim=NDIM, dtype=DTYPE)
        

class Memref1DIndex(NumpyConversionMixin):
    DTYPE = np.dtype(np.uintp)
    NDIM = 1

    _fields_ = _get_fields(ndim=NDIM, dtype=DTYPE)

class CsrDouble(ctypes.Structure):
    _fields_ = [
        ("indptr", Memref1DIndex),
        ("indices", Memref1DIndex),
        ("data", Memref1DDouble),
    ]

    @classmethod
    def from_scipy_sparse(cls, arr: sps.csr_array | sps.csr_matrix) -> typing.Self:
        indptr = Memref1DIndex.from_numpy(arr.indptr.astype(np.int64))
        indices = Memref1DIndex.from_numpy(arr.indices.astype(np.int64))
        data = Memref1DDouble.from_numpy(arr.data.astype(np.float64))

        return cls(indptr=indptr, indices=indices, data=data)
    
    def to_scipy_sparse(self, *, shape: tuple[int, int]) -> sps.csr_array:
        indptr = self.indptr.to_numpy()[:shape[0]+1]
        indices = self.indices.to_numpy()[indptr[0]:indptr[-1]]
        data = self.data.to_numpy()[indptr[0]:indptr[-1]]
        if indptr[0] != 0:
            indptr -= indptr[0]

        return sps.csr_array((data, indices, indptr), shape=shape)
    
    @classmethod
    def new(cls, *, shape: tuple[int, int], nnz: int) -> typing.Self:
        return cls(indptr=Memref1DIndex.new(shape=shape[0] + 1), indices=Memref1DIndex.new(shape=nnz), data=Memref1DDouble.new(shape=nnz))

    def __hash__(self) -> int:
        return hash(id(self))
    
    def __eq__(self, value: object) -> bool:
        return self is value


cdll = ctypes.CDLL("/Users/habbasi/Quansight/llvm-project/scratch/a.dylib")
cdll._mlir_ciface_sddmm_kernel.restype = None
cdll.free.restype = None

def debug_print(name: str, arr: np.ndarray | sps.csr_array) -> None:
    print(f"--- Printing {name} ---")
    if isinstance(arr, np.ndarray):
        print(f"{arr=}")
    
    elif isinstance(arr, sps.csr_array | sps.csr_matrix):
        print(f"{arr.indptr=}")
        print(f"{arr.indices=}")
        print(f"{arr.data=}")
    
    print(f"--- End Printing {name} ---")
    

def sddmm(a: np.ndarray, b: np.ndarray, c: sps.csr_array | sps.csr_matrix) -> sps.csr_array:
    out_shape = c.shape
    nnz = c.nnz
    a_ref = Memref2DDouble.from_numpy(a)
    b_ref = Memref2DDouble.from_numpy(b)
    c_indptr_tmp = c.indptr.astype(np.int64)
    c_indices_tmp = c.indices.astype(np.int64)

    c_intptr_ref = Memref1DIndex.from_numpy(c_indptr_tmp)
    c_indices_ref = Memref1DIndex.from_numpy(c_indices_tmp)
    c_data_ref = Memref1DDouble.from_numpy(c.data)

    out = CsrDouble()
    if not DIRECT_OUT:
        out_indptr = np.empty(shape=out_shape[0]+1, dtype=np.int64)
        out_indices = np.empty(shape=nnz, dtype=np.int64)
        out_data = np.empty(shape=nnz, dtype=np.float64)
        out_intptr_ref = Memref1DIndex.from_numpy(out_indptr)
        out_indices_ref = Memref1DIndex.from_numpy(out_indices)
        out_data_ref = Memref1DDouble.from_numpy(out_data)
        cdll._mlir_ciface_sddmm_kernel(
            ctypes.byref(out),
            ctypes.byref(a_ref),
            ctypes.byref(b_ref),
            ctypes.byref(c_intptr_ref), ctypes.byref(c_indices_ref), ctypes.byref(c_data_ref),
            ctypes.byref(out_intptr_ref), ctypes.byref(out_indices_ref), ctypes.byref(out_data_ref),
        )
        return sps.csr_array((out_data, out_indices, out_indptr), shape=out_shape)
    else:
        cdll._mlir_ciface_sddmm_kernel(
            ctypes.byref(out),
            ctypes.byref(a_ref),
            ctypes.byref(b_ref),
            ctypes.byref(c_intptr_ref), ctypes.byref(c_indices_ref), ctypes.byref(c_data_ref),
        )
        ret = out.to_scipy_sparse(shape=out_shape)
        return ret

if __name__ == "__main__":
    LEN = 10000
    DENSITY = 0.0001
    DIRECT_OUT = False
    arr1 = np.random.random((LEN, LEN))
    arr2 = np.random.random((LEN, LEN))
    arr3 = sps.random_array((LEN, LEN), format="csr", density=DENSITY)

    start = time.time_ns()
    actual = sddmm(arr1, arr2, arr3)
    print(f"Fused SDDMM took {(time.time_ns() - start) / 10 ** 9} s.")
    start = time.time_ns()
    expected = ((arr1 @ arr2) * arr3).asformat("csr")
    print(f"Non-fused SDDMM took {(time.time_ns() - start) / 10 ** 9} s.")
    
    np.testing.assert_equal(actual.indptr, expected.indptr)
    np.testing.assert_equal(actual.indices, expected.indices)
    np.testing.assert_allclose(actual.data, expected.data)
    print("Success!")

A couple of final questions/requests:

  1. Dynamic shapes on CSR don’t seem to work even with direct-out=True. Is that intended? Am I compiling it wrong, perhaps? Can I help with, or formally request, adding support for this?
  2. The NumPy version is using parallelism to speed things up, which makes it hard to beat. However, in this particular instance, there is a simple way to speed up the dot product inside the SDDMM. Is emitting that kind of parallelism supported?
  3. The sparsity structure can just be copied in this instance without the need for sparse inserts. Can we perhaps emit that sort of code?

Small tip: don’t mark your own post containing a follow-up question as the “Solution” of the thread. Chances are people won’t come back to read further :wink:

The reason your SDDMM does not perform well is that you expressed it incorrectly, causing the sparse compiler to think that a whole new sparse matrix is being created (with expand/compress acceleration, but still slower).

There are many in-tree examples of SDDMM that work. Here is my favorite (it yields efficient code and even maps elegantly to GPU; that said, loop permutations relative to the default code sometimes yield better performance, depending on sparsity structure and target architecture):

  func.func @SDDMM(%args: tensor<?x?xf32, #CSR>,
                   %arga: tensor<?x?xf32>,
                   %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #CSR> {
    %result = linalg.generic #trait_SDDMM
      ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%args: tensor<?x?xf32, #CSR>) {
        ^bb(%a: f32, %b: f32, %s: f32):
           %f0 = arith.constant 0.0 : f32
           %u = sparse_tensor.unary %s : f32 to f32
             present={
                ^bb0(%p: f32):
                  %mul = arith.mulf %a, %b : f32
                  sparse_tensor.yield %mul : f32
             }
             absent={}
           %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
              ^bb0(%p: f32, %q: f32):
                %add = arith.addf %p, %q : f32
                sparse_tensor.yield %add : f32
            }
           linalg.yield %r : f32
      } -> tensor<?x?xf32, #CSR>
    return %result : tensor<?x?xf32, #CSR>
  }