Sparse Compiler and GPU Code-Generation

It is time to continue the older sparse compiler thread into a new topical thread.

Even though the MLIR sparse compiler is obviously meant as a retargetable tool for exploiting sparsity, most development so far has focused on generating sparse code that runs on a CPU. Although CPUs are the least commonly used targets for ML workloads, the ability to easily exploit unstructured sparsity still makes this a viable approach for accelerating sparse problems with very high sparsities.

Recently, however, we started to look into accelerating the generated sparse code on GPUs as well, with a new focus on exploiting structured sparsity (for example, block sparsity and 2:4 sparsity). To this end, we started to develop a prototype GPU code generator. It is extremely basic, uses a very simple memory-passing scheme between host and device, and does not yield much performance gain yet. But we hope it is the first step towards a much better GPU code generator that allows for hybrid execution, with unstructured sparsity running on the CPU and structured sparsity accelerated on the GPU.

You can find the very first step (along with some follow-up revisions that define a working compiler pipeline and an end-to-end example). This very primitive GPU code generator basically converts the outermost loop of the generated sparse kernel into threaded code. For example, something like this:

func.func @matvec(%A: tensor<?x?xf64, #CSR>, 
                  %x: tensor<?xf64>,
                  %y_in: tensor<?xf64>) -> tensor<?xf64> {
  %y_out = linalg.matvec
      ins(%A, %x: tensor<?x?xf64, #CSR>, tensor<?xf64>)
      outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
  return %y_out : tensor<?xf64>
}

is sparsified and then parallelized as follows, where the kernel parameters are host-registered buffers that contain the sparse matrix in CSR format together with the dense vectors.

gpu.module @sparsekernels {
    gpu.func @kernel(%arg0: index,
                     %arg1: memref<?xf64>,
                     %arg2: memref<?xindex>,
                     %arg3: memref<?xindex>,
                     %arg4: memref<?xf64>,
                     %arg5: memref<?xf64>) kernel {
      %c1 = arith.constant 1 : index
      %0 = gpu.block_id  x
      %1 = gpu.block_dim  x 
      %2 = gpu.thread_id  x 
      %3 = gpu.grid_dim  x
      // Compute the global thread index and the total number of threads,
      // so that the rows are processed in a grid-stride loop.
      %4 = arith.muli %0, %1 : index
      %5 = arith.addi %4, %2 : index
      %6 = arith.muli %1, %3 : index
      scf.for %arg6 = %5 to %arg0 step %6 {
        // Load y[i] and the bounds of row i from the CSR positions array.
        %7 = memref.load %arg1[%arg6] : memref<?xf64>
        %8 = memref.load %arg2[%arg6] : memref<?xindex>
        %9 = arith.addi %arg6, %c1 : index
        %10 = memref.load %arg2[%9] : memref<?xindex>
        // Accumulate y[i] += A[i,j] * x[j] over the nonzeros of row i.
        %11 = scf.for %arg7 = %8 to %10 step %c1 iter_args(%arg8 = %7) -> (f64) {
          %12 = memref.load %arg3[%arg7] : memref<?xindex>
          %13 = memref.load %arg4[%arg7] : memref<?xf64>
          %14 = memref.load %arg5[%12] : memref<?xf64>
          %15 = arith.mulf %13, %14 : f64
          %16 = arith.addf %arg8, %15 : f64
          scf.yield %16 : f64
        } {"Emitted from" = "linalg.generic"}
        memref.store %11, %arg1[%arg6] : memref<?xf64>
      }
      gpu.return
    }
}
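
For completeness, the host side of this simple memory-passing scheme is not shown above. Conceptually it just registers the host buffers with the GPU runtime and then launches the kernel over a one-dimensional grid, roughly along the lines of the hand-written sketch below (buffer names are made up for illustration, and the actual generated IR differs in its details):

// Make the host-allocated CSR buffers accessible to the device.
%cast_y = memref.cast %y : memref<?xf64> to memref<*xf64>
gpu.host_register %cast_y : memref<*xf64>
%cast_pos = memref.cast %positions : memref<?xindex> to memref<*xindex>
gpu.host_register %cast_pos : memref<*xindex>
// ... similarly for the coordinates, values, and x buffers ...

// Launch the kernel above over a one-dimensional grid of threads.
gpu.launch_func @sparsekernels::@kernel
    blocks in (%gridx, %c1, %c1)
    threads in (%blockx, %c1, %c1)
    args(%num_rows : index, %y : memref<?xf64>,
         %positions : memref<?xindex>, %coordinates : memref<?xindex>,
         %values : memref<?xf64>, %x : memref<?xf64>)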

And, yes, before the experts chime in, this is typically not the best way to make SpMV parallel :slight_smile:
But all the basic building blocks are now in place to develop GPU code generation further into something with higher performance, especially when focusing on structured sparsity.

Your insights and ideas are welcomed here!
Stay tuned for updates!


Some more progress!

The sparse compiler now has two prototype strategies for generating CUDA:

  1. CUDA codegen: this converts sparsified code to CUDA threads
  2. CUDA libgen: this converts pre-sparsified code to cuSPARSE library calls

An example of the former was shown above. An example of the latter is illustrated below (note that I have extended the GPU dialect with cuSPARSE support; I will send that out for review shortly, since it may trigger some discussion on the proper way to represent these operations and on whether async tokens are required, but the basic mechanism is ready to be deployed!).

 func.func @matvec(%A: tensor<?x?xf64, #SortedCOO>,
                    %x: tensor<?xf64>,
                    %y_in: tensor<?xf64>) -> tensor<?xf64> {
    %y_out = linalg.matvec
      ins(%A, %x: tensor<?x?xf64, #SortedCOO>, tensor<?xf64>)
      outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
    return %y_out : tensor<?xf64>
  }

lowers directly into cuSPARSE:

    %16 = gpu.create_sparse_env
    %17 = gpu.create_coo %1, %2, %dim, %memref, %memref_2, %memref_5 : memref<?xindex>, memref<?xindex>, memref<?xf64>
    %18 = gpu.create_dn_vec %memref_8, %2 : memref<?xf64>
    %19 = gpu.create_dn_vec %memref_11, %1 : memref<?xf64>
    %20 = gpu.spmv_buffer_size %16, %17, %18, %19
    %21 = gpu.wait async
    %memref_13, %asyncToken_14 = gpu.alloc async [%21] (%20) : memref<?xi8>
    gpu.wait [%asyncToken_14]
    gpu.spmv %16, %17, %18, %19, %memref_13 : memref<?xi8>
    gpu.destroy_sp_mat %17
    gpu.destroy_dn_vec %18
    gpu.destroy_dn_vec %19
    gpu.destroy_sparse_env %16
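
For reference, the #SortedCOO (and earlier #CSR) annotations used in these examples are ordinary sparse tensor encodings whose definitions were omitted from the snippets. In the current syntax of the sparse tensor dialect they would look roughly as follows (a sketch; the exact attribute syntax has evolved over time):

#CSR = #sparse_tensor.encoding<{
  map = (i, j) -> (i : dense, j : compressed)
}>

#SortedCOO = #sparse_tensor.encoding<{
  map = (i, j) -> (i : compressed(nonunique), j : singleton)
}>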

The first revision with the GPU dialect changes and the mlir_cuda_runtime wrappers around cuSPARSE is out for review in ⚙ D150152 [gpu][sparse] add gpu ops for sparse matrix computations (I could use some help figuring out the CMake syntax for including the cuSPARSE headers and library).

I will send out a follow-up CL that uses this from the sparse compiler pipeline shortly so that the changes fall into the right context.

This revision ⚙ D150170 [mlir][sparse][gpu] first implementation of the GPU libgen approach shows how the actual libgen path works in the sparse compiler pipeline. Currently this only handles matvec on COO with F64 data, but generalizing it to other kernels and data types is straightforward once the initial, larger revisions have been accepted.

Lastly, I plan to add an end-to-end test to the stack of changes. After that, the code will be generalized and refined, followed by actual performance analysis of course.

And finally for today, one end-to-end example is in ⚙ D150172 [mlir][sparse][gpu] end-to-end integration test of GPU libgen approach. As stated before, lots of refinements will follow, but I would like to get consensus on the basic approach first.

Sometimes it is just fun to look back and see how far we have come with automatic sparsification in MLIR!

For example, thanks to great teamwork in the Google sparse compiler team (including @PeimingLiu and @K-Wu), proper design by folks such as @nicolasvasilache, @mehdi_amini, @matthias-springer, and @ThomasRaoux, and external contributions from people like @jim22k, we can now express a very idiomatic form of SDDMM, using a “spy” function combined with the in-place update S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j), in the linalg and sparse tensor dialects as follows.

#trait_sampled_dense_dense = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,k)>,  // A
    affine_map<(i,j,k) -> (k,j)>,  // B
    affine_map<(i,j,k) -> (i,j)>   // S
  ],
  iterator_types = ["parallel", "parallel", "reduction"],
  doc = "S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j)"
}

 func.func @sparse_sampled_dd(%argA: tensor<8x8xf64>,
                               %argB: tensor<8x8xf64>,
                               %argS: tensor<8x8xf64, #CSR>)
                                   -> tensor<8x8xf64, #CSR> {
    %f0 = arith.constant 0.0 : f64
    %result = linalg.generic #trait_sampled_dense_dense
      ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>)
      outs(%argS: tensor<8x8xf64, #CSR>) {
        ^bb(%a: f64, %b: f64, %s: f64):
           %u = sparse_tensor.unary %s : f64 to f64
             present={
                ^bb0(%p: f64):
                  %mul = arith.mulf %a, %b : f64
                  sparse_tensor.yield %mul : f64
             }
             absent={}
           %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
              ^bb0(%p: f64, %q: f64):
                %add = arith.addf %p, %q : f64
                sparse_tensor.yield %add : f64
            }
           linalg.yield %r : f64
    } -> tensor<8x8xf64, #CSR>
    return %result : tensor<8x8xf64, #CSR>
  }

This semi-ring formulation of performing the spy-sampling and addition in-place is eventually recognized as a special operation in the GPU Dialect.

%15 = gpu.sddmm async [%asyncToken_18]
    %env, %dnTensor, %dnTensor_13, %spmat, %memref_17
    : memref<?xi8> into f64

This eventually maps to the cusparseSDDMM() routine implemented by our friends at NVIDIA to take full advantage of the GPU.

Such a pleasure to see the concept of progressive lowering bridge the dense and sparse worlds in action!


Thanks to the great work of @K-Wu, automatic sparsification in MLIR can now also exploit NVIDIA’s 2:4 structured sparsity on GPUs (through the cuSPARSELt library for the time being). A Linalg operation that multiplies a 2:4 sparse matrix with a dense matrix is eventually converted into something like this:

 %spMat, %asyncToken_4 = gpu.create_2to4_spmat async [%6] {PRUNE_AND_CHECK}
           %c16, %c16, %memref : memref<16x16xf16>
 ...
 %7 = gpu.spmm async [%asyncToken_14]
      %spMat, %dnTensor, %dnTensor_6, %memref_9, %memref_11, %memref_13
      : memref<?xi8>, memref<?xi8>, memref<?xi8> into f16

Subsequently, this is lowered into a call to a runtime wrapper around cuSPARSELt:

llvm.call @mgpuCuSparseLtSpMM(%94, %97, %100, %118, %130, %142, %91)
   : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr,
      !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()

Note that the PRUNE_AND_CHECK attribute here means that the input matrix (still in uncompressed dense form) is first pruned (with a sanity check) before the actual compress operation is performed, which converts the matrix into a compressed form that holds only the nonzero values together with a small metadata array recording their positions within each 1x4 strip.
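
For context, the input that triggers this whole path is just an ordinary Linalg matrix multiplication in which the A operand carries a 2:4 sparse tensor encoding. A rough sketch of what that looks like (the exact spelling of the 2:4 level type is an assumption based on the current sparse tensor syntax and may differ from the code used at the time):

#NV_24 = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i            : dense,
    j floordiv 4 : dense,
    j mod 4      : structured[2, 4])
}>

func.func @matmul24(%A: tensor<16x16xf16, #NV_24>,
                    %B: tensor<16x16xf16>,
                    %C: tensor<16x16xf16>) -> tensor<16x16xf16> {
  %D = linalg.matmul
    ins(%A, %B: tensor<16x16xf16, #NV_24>, tensor<16x16xf16>)
    outs(%C: tensor<16x16xf16>) -> tensor<16x16xf16>
  return %D : tensor<16x16xf16>
}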

As a fun experiment, I ran this kernel D = A B + C with B the identity and C zero, so that you can directly see the effect of pruning A: in each 1x4 strip, two values are zeroed out so as to maximize the L1-norm of the remaining strip:

A original:

( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )

A pruned (but not yet compressed):

( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )

During the actual computation, of course, the compressed form of A is used (with the zeros omitted entirely); each 1x4 strip (4, 0, 4, 0), for example, is stored as just the two values (4, 4) together with metadata recording their positions 0 and 2 within the strip.


Do you find the cuSPARSE API for SpGEMM a bit intimidating? Not to worry: simply rely on MLIR’s sparsifier to generate the right cuSPARSE code for you to exploit a GPU.

Simply start with:

  // Computes C = A x B with A,B,C sparse CSR.
  func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
                       %B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
    %init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>
    %C = linalg.matmul
      ins(%A, %B: tensor<8x8xf32, #CSR>,
                  tensor<8x8xf32, #CSR>)
      outs(%init: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
    return %C: tensor<8x8xf32, #CSR>
  }

and then invoke MLIR’s sparsifier pipeline to obtain IR in the GPU dialect that eventually makes all those calls to, e.g., cusparseSpGEMM_workEstimation() and cusparseSpGEMM_compute() just right for you! (Note in the generated IR below how each of the two phases is invoked twice: first with a zero-sized buffer to query the required buffer size, and then again with the actual buffer once it has been allocated.)

    ...
    %desc, %asyncToken_26 = gpu.spgemm_create_descr async [%asyncToken_25]
    %bufferSzNew, %asyncToken_27 = gpu.spgemm_work_estimation_or_compute async [%asyncToken_26]{ WORK_ESTIMATION} %spmat, %spmat_16, %spmat_24, %desc, %c0, %memref_22 : f32 into memref<?xf32>
    %memref_28, %asyncToken_29 = gpu.alloc async [%asyncToken_27] (%bufferSzNew) : memref<?xi8>
    %bufferSzNew_30, %asyncToken_31 = gpu.spgemm_work_estimation_or_compute async [%asyncToken_29]{ WORK_ESTIMATION} %spmat, %spmat_16, %spmat_24, %desc, %bufferSzNew, %memref_28 : f32 into memref<?xi8>
    %bufferSzNew_32, %asyncToken_33 = gpu.spgemm_work_estimation_or_compute async [%asyncToken_31]{ COMPUTE} %spmat, %spmat_16, %spmat_24, %desc, %c0, %memref_22 : f32 into memref<?xf32>
    %memref_34, %asyncToken_35 = gpu.alloc async [%asyncToken_33] (%bufferSzNew_32) : memref<?xi8>
    %bufferSzNew_36, %asyncToken_37 = gpu.spgemm_work_estimation_or_compute async [%asyncToken_35]{ COMPUTE} %spmat, %spmat_16, %spmat_24, %desc, %bufferSzNew_32, %memref_34 : f32 into memref<?xi8>
    %rows, %cols, %nnz, %asyncToken_38 = gpu.spgemm_get_size async [%asyncToken_37] %spmat_24
    %memref_39, %asyncToken_40 = gpu.alloc async [%asyncToken_38] (%nnz) : memref<?xi32>
    %memref_41, %asyncToken_42 = gpu.alloc async [%asyncToken_40] (%nnz) : memref<?xf32>
    %22 = gpu.set_csr_pointers async [%asyncToken_42] %spmat_24, %memref_18, %memref_39, %memref_41 : memref<?xi32>, memref<?xi32>, memref<?xf32>
    %23 = gpu.spgemm_copy async [%22] %spmat, %spmat_16, %spmat_24, %desc : f32
    ...

Note that this posting showed how to express block sparsity with the new sparse tensor type syntax, and how it maps directly to NVIDIA’s BSR format: the first two level expressions select a 2x2 block, while the last two address the elements within each block. As a result, the GPU path through the MLIR Sparsifier can directly exploit the cuSPARSE kernels that support BSR, as shown below.

#BSR = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i floordiv 2 : dense,
    j floordiv 2 : compressed,
    i mod 2 : dense,
    j mod 2 : dense)
}>

  func.func @SDDMM_block(%args: tensor<?x?xf32, #BSR>,
                         %arga: tensor<?x?xf32>,
                         %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #BSR> {
    ...
  }

=>

... = llvm.call @mgpuCreateBsr(...)

=>

cusparseCreateBsr(...)
