Pass to convert some ops to sparse

I’ve got it in my mind that I want to create a pass which uses some heuristic to promote some parts of a tensor program to sparse and then experiment with the codegen (since for deployment, we often have constant weights and can make some easy decisions based on that). I’m kind of treating this as an intro project to the sparse side of the world.

Since all of this infra is brand new and has never been connected together end to end, I thought I’d play it forward step by step. Maybe the journey will be useful as a sample/documentation in the end.

To get started, I wrote the following little sample program:

from mlir.ir import *
from mlir.dialects.builtin import *
from mlir.dialects.tosa import *
from mlir.passmanager import *
import mlir.dialects.sparse_tensor as st

import mlir.conversions

def sparse_tensor(shape, levels=None, ordering=None, dtype=None):
  # Ranked tensor type carrying a sparse_tensor encoding; defaults to
  # all-compressed levels, identity ordering, 32-bit pointers/indices.
  rank = len(shape)
  if not levels:
    levels = [st.DimLevelType.compressed] * rank
  if not ordering:
    ordering = AffineMap.get_identity(rank)
  encoding = st.EncodingAttr.get(levels, ordering, 32, 32)
  return RankedTensorType.get(shape,
    dtype if dtype else F32Type.get(), encoding=encoding)

def dense_tensor(shape, dtype=None):
  return RankedTensorType.get(shape,
    dtype if dtype else F32Type.get())


def create_sample_fc_module():
  # tosa.fully_connected with dense activations, sparse weights, dense bias.
  m = Module.create()
  with InsertionPoint(m.body):
    @FuncOp.from_py_func(
        dense_tensor([256, 1024]),
        sparse_tensor([64, 1024]),
        dense_tensor([64]))
    def fc(inputs, weights, bias):
      d0 = RankedTensorType(inputs.type).get_dim_size(0)
      d1 = RankedTensorType(weights.type).get_dim_size(0)
      result_type = dense_tensor([d0, d1])
      return FullyConnectedOp(
        result_type,
        input=inputs, weight=weights, bias=bias,
        quantization_info=None).result
  return m

with Context() as ctx, Location.unknown():
  m = create_sample_fc_module()
  print("// Input module")
  print(m)

  pm = PassManager.parse("func(tosa-to-linalg-on-tensors)")
  pm.run(m)

  print("\n\n// Post linalg conversion")
  print(m)

Which dutifully prints:

// Input module
module  {
  func @fc(%arg0: tensor<256x1024xf32>, %arg1: tensor<64x1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 32, indexBitWidth = 32 }>>, %arg2: tensor<64xf32>) -> tensor<256x64xf32> {
    %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<256x1024xf32>, tensor<64x1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 32, indexBitWidth = 32 }>>, tensor<64xf32>) -> tensor<256x64xf32>
    return %0 : tensor<256x64xf32>
  }
}

// Post linalg conversion
#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>
module  {
  func @fc(%arg0: tensor<256x1024xf32>, %arg1: tensor<64x1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 32, indexBitWidth = 32 }>>, %arg2: tensor<64xf32>) -> tensor<256x64xf32> {
    %0 = linalg.init_tensor [256, 64] : tensor<256x64xf32>
    %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<64xf32>) outs(%0 : tensor<256x64xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
      linalg.yield %arg3 : f32
    } -> tensor<256x64xf32>
    %cst = constant dense<[1, 0]> : tensor<2xi64>
    %2 = linalg.init_tensor [1024, 64] : tensor<1024x64xf32>
    %3 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<64x1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 32, indexBitWidth = 32 }>>) outs(%2 : tensor<1024x64xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
      linalg.yield %arg3 : f32
    } -> tensor<1024x64xf32>
    %4 = linalg.matmul ins(%arg0, %3 : tensor<256x1024xf32>, tensor<1024x64xf32>) outs(%1 : tensor<256x64xf32>) -> tensor<256x64xf32>
    return %4 : tensor<256x64xf32>
  }
}

First off, this has a couple of problems:

  • The lowering from tosa.fully_connected looks wrong to my eyes (at a minimum, I would have expected to see an addf somewhere for the bias vector).
  • The conversions do not do any propagation of the tensor encoding, which may or may not be what we want (but is almost certainly not thought through for this case).

And some style nits:

  • It would be really nice if the tensor encoding were pulled up as an attribute alias at the top of the module (the way the affine maps are hoisted to #map0, #map1, ...). It is quite hard to read as-is.
  • I can see half a dozen things that should/could be better in the Python API.

Would love it if folks who have worked on this could help me choose my own adventure here and discuss/highlight next steps. Any of you all interested in collaborating towards a worked example here? @aartbik @rsuderman @sjarus


If we call the input matrices A and B, with A dense and B sparse, then I read this IR as follows.

D = B^T    ; with D dense
return A D 

Without fusing the kernels, this will not exploit any sparsity. It really should express this without storing into dense D first (note that I am also working on sparse outputs, so that we could keep D sparse, but fusion will be much better in this case).
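
To make the same point concretely, here is a rough NumPy analogy (shapes taken from the example; the names A, B, and bias are just for illustration):

  import numpy as np

  A = np.random.rand(256, 1024).astype(np.float32)   # dense activations
  B = np.random.rand(64, 1024).astype(np.float32)    # weights (conceptually sparse)
  bias = np.random.rand(64).astype(np.float32)

  # Unfused: the transpose is materialized as a dense 1024x64 temporary D,
  # so the matmul that follows can never exploit a sparse storage of B.
  D = np.ascontiguousarray(B.T)
  unfused = bias + A @ D

  # Fused: a single kernel indexes B[j, k] in place, which is what allows a
  # sparse compiler to iterate only over B's stored entries.
  fused = bias + np.einsum('ik,jk->ij', A, B)

  assert np.allclose(unfused, fused, rtol=1e-4, atol=1e-4)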

Makes sense, and this is one of those cases where someone has pre-factored the op into a column major form in an attempt to be helpful. In many real scenarios, this would fold into the constant weights, I think.

Generalizing and fusing via:

  pm = PassManager.parse("func(linalg-generalize-named-ops), func(linalg-fusion-for-tensor-ops)")
  pm.run(m)
  print("\n\n//Post fusion")
  print(m)

Now produces:

//Post fusion
#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
module  {
  func @fc(%arg0: tensor<256x1024xf32>, %arg1: tensor<64x1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 32, indexBitWidth = 32 }>>, %arg2: tensor<64xf32>) -> tensor<256x64xf32> {
    %0 = linalg.init_tensor [256, 64] : tensor<256x64xf32>
    %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<64xf32>) outs(%0 : tensor<256x64xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
      linalg.yield %arg3 : f32
    } -> tensor<256x64xf32>
    %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x1024xf32>, tensor<64x1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 32, indexBitWidth = 32 }>>) outs(%1 : tensor<256x64xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
      %3 = mulf %arg3, %arg4 : f32
      %4 = addf %arg5, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<256x64xf32>
    return %2 : tensor<256x64xf32>
  }
}

Leaving aside what I think is a bug with the bias, this looks better, right?

Yes, the sparse compiler kicks in, and happily generates the following code. Of course, the power of sparse compilers really surfaces by generating all combinations of dense/compressed row-/column-wise annotations for B and determining which one yields the best performance!

  func @fc(%arg0: tensor<256x1024xf32>, %arg1: !llvm.ptr<i8>, %arg2: tensor<64xf32>) -> tensor<256x64xf32> {
    %c256 = constant 256 : index
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %0 = linalg.init_tensor [256, 64] : tensor<256x64xf32>
    %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<64xf32>) outs(%0 : tensor<256x64xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
      linalg.yield %arg3 : f32
    } -> tensor<256x64xf32>
    %2 = memref.buffer_cast %arg0 : memref<256x1024xf32>
    %3 = call @sparsePointers32(%arg1, %c0) : (!llvm.ptr<i8>, index) -> memref<?xi32>
    %4 = call @sparseIndices32(%arg1, %c0) : (!llvm.ptr<i8>, index) -> memref<?xi32>
    %5 = call @sparsePointers32(%arg1, %c1) : (!llvm.ptr<i8>, index) -> memref<?xi32>
    %6 = call @sparseIndices32(%arg1, %c1) : (!llvm.ptr<i8>, index) -> memref<?xi32>
    %7 = call @sparseValuesF32(%arg1) : (!llvm.ptr<i8>) -> memref<?xf32>
    %8 = memref.buffer_cast %1 : memref<256x64xf32>
    %9 = memref.alloc() : memref<256x64xf32>
    linalg.copy(%8, %9) : memref<256x64xf32>, memref<256x64xf32> 
    scf.for %arg3 = %c0 to %c256 step %c1 {
      %11 = memref.load %3[%c0] : memref<?xi32>
      %12 = zexti %11 : i32 to i64
      %13 = index_cast %12 : i64 to index
      %14 = memref.load %3[%c1] : memref<?xi32>
      %15 = zexti %14 : i32 to i64
      %16 = index_cast %15 : i64 to index
      scf.for %arg4 = %13 to %16 step %c1 {
        %17 = memref.load %4[%arg4] : memref<?xi32>
        %18 = zexti %17 : i32 to i64
        %19 = index_cast %18 : i64 to index
        %20 = memref.load %5[%arg4] : memref<?xi32>
        %21 = zexti %20 : i32 to i64
        %22 = index_cast %21 : i64 to index
        %23 = addi %arg4, %c1 : index
        %24 = memref.load %5[%23] : memref<?xi32>
        %25 = zexti %24 : i32 to i64
        %26 = index_cast %25 : i64 to index
        %27 = memref.load %9[%arg3, %19] : memref<256x64xf32>
        %28 = scf.for %arg5 = %22 to %26 step %c1 iter_args(%arg6 = %27) -> (f32) {
          %29 = memref.load %6[%arg5] : memref<?xi32>
          %30 = zexti %29 : i32 to i64
          %31 = index_cast %30 : i64 to index
          %32 = memref.load %2[%arg3, %31] : memref<256x1024xf32>
          %33 = memref.load %7[%arg5] : memref<?xf32>
          %34 = mulf %32, %33 : f32
          %35 = addf %arg6, %34 : f32
          scf.yield %35 : f32
        }
        memref.store %28, %9[%arg3, %19] : memref<256x64xf32>
      }
    }
    %10 = memref.tensor_load %9 : memref<256x64xf32>
    return %10 : tensor<256x64xf32>
  }
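
To actually sweep the dense/compressed combinations mentioned above, a rough sketch could reuse the sparse_tensor() helper from the first snippet (this assumes create_sample_fc_module() is parameterized to accept the weight type rather than hard-coding it):

  import itertools

  level_choices = [st.DimLevelType.dense, st.DimLevelType.compressed]
  for levels in itertools.product(level_choices, repeat=2):
    with Context(), Location.unknown():
      weight_type = sparse_tensor([64, 1024], levels=list(levels))
      # Hypothetical parameterized variant of create_sample_fc_module().
      m = create_sample_fc_module(weight_type)
      # ... run the lowering pipeline on m and benchmark the resulting kernel
      # to pick the annotation that works best for the actual weight sparsity.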

Nice - do we have enough in tree yet to utter the incantation that does that?

This is all in tree already, but as usual, our “flag story” is elaborate. Below I show a typical flag combination (and you may need even more in some cases). I tried to explain what is what with comments. Obviously, in the long run, we want to provide convenience passes that do the right thing.

  --sparsification           ; triggers sparse compiler, can also set vector/parallel strategy
  --sparse-tensor-conversion ; converts sparse types to the opaque runtime pointer + lib calls
  --convert-linalg-to-loops 
  --convert-vector-to-scf 
  --convert-scf-to-std       ; typical other dialect lowerings 
  --func-bufferize
  --tensor-constant-bufferize
  --tensor-bufferize
  --std-bufferize
  --finalizing-bufferize    ; good stuff! all this bufferizes as needed
  --convert-vector-to-llvm  ; needed if you vectorized sparse code
  --convert-std-to-llvm     ; finally ready to ship this off to LLVM IR!
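
From the Python bindings, a sketch of stringing those together (assuming each flag maps one-to-one onto a textual pipeline element; some passes may need to be anchored under func(...) depending on your revision):

  pipeline = ",".join([
    "sparsification",
    "sparse-tensor-conversion",
    "convert-linalg-to-loops",
    "convert-vector-to-scf",
    "convert-scf-to-std",
    "func-bufferize",
    "tensor-constant-bufferize",
    "tensor-bufferize",
    "std-bufferize",
    "finalizing-bufferize",
    "convert-vector-to-llvm",
    "convert-std-to-llvm",
  ])
  pm = PassManager.parse(pipeline)
  pm.run(m)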

Looks like the bias term is handled by using it as the initial value for the matmul result tensor. %1 is the broadcasted bias term, and the linalg.matmul has %1 as its outs operand (it accumulates “into” the bias value).
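
A NumPy analogy of that accumulation (reusing A, B, and bias from the sketch above; the broadcast plays the role of %1 and the matmul accumulates into it through its outs operand):

  out = np.broadcast_to(bias, (256, 64)).copy()  # %1: bias broadcast to the output shape
  out += A @ B.T                                 # linalg.matmul accumulating into outs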

Ah, indeed - thanks! I was so fixated on “the missing addf” that my mind was down the wrong branch…