Linalg packing for Matmul: who does the packing?

Hi,
In my quest to improve the codegen produced for matmul with MLIR (starting from Linalg), I have reached 73% of the theoretical peak. Now I am trying to improve further, so I have started digging into the generated code. My first focus has been packing (following the BLIS approach).

Basically, I have this matmul written in Linalg:

func @gemm(%A: memref<2048x2048xf32>, %B: memref<2048x2048xf32>, %C: memref<2048x2048xf32>) {
  linalg.matmul ins(%A, %B : memref<2048x2048xf32>, memref<2048x2048xf32>)
                outs(%C : memref<2048x2048xf32>)
  return
}

And I apply the following strategy (see the sketch after the list):

  • L1/L2/L3 tiling of (mc)256x(nc)128x(kc)64
  • Register tiling of (mr)8x(nr)8x(kr)64
  • No copy/fill tiling
  • promote
  • vectorize
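
To keep the intent concrete, here is a rough NumPy sketch of the same blocking (the structure and names are mine, not what the Linalg passes emit). The two ascontiguousarray copies play the role of the promoted A and B blocks; a BLIS-style pack would additionally reorder them into mr/nr panels:

import numpy as np

# Hypothetical reference for the schedule above (mc=256, nc=128, kc=64, mr=8, nr=8).
MC, NC, KC, MR, NR = 256, 128, 64, 8, 8

def blocked_gemm(A, B, C):
    M, K = A.shape
    _, N = B.shape
    for i in range(0, M, MC):            # tile over rows of A / C
        for j in range(0, N, NC):        # tile over columns of B / C
            for p in range(0, K, KC):    # tile over the reduction dimension
                # "Promotion": copy the current blocks into small contiguous buffers.
                Ablk = np.ascontiguousarray(A[i:i+MC, p:p+KC])   # 256x64
                Bblk = np.ascontiguousarray(B[p:p+KC, j:j+NC])   # 64x128
                # Register tiling: mr x nr micro-tiles over the promoted blocks.
                for ii in range(0, MC, MR):
                    for jj in range(0, NC, NR):
                        C[i+ii:i+ii+MR, j+jj:j+jj+NR] += (
                            Ablk[ii:ii+MR, :] @ Bblk[:, jj:jj+NR])
    return C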

The intermediate result that I get once I apply the Linalg passes is the following:

#map0 = affine_map<(d0, d1) -> (d0 * 64 + d1)>
#map1 = affine_map<(d0, d1) -> (d0 * 128 + d1)>
#map2 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map3 = affine_map<(d0, d1)[s0] -> (d0 * 64 + s0 + d1)>
#map4 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
module  {
  func @gemm(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) {
    %c65536 = constant 65536 : index
    %c32768 = constant 32768 : index
    %c128 = constant 128 : index
    %c64 = constant 64 : index
    %c2048 = constant 2048 : index
    %c8 = constant 8 : index
    %c256 = constant 256 : index
    %cst = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %0 = alloca() : memref<8xvector<64xf32>>
    %1 = alloca() : memref<64xvector<8xf32>>
    %2 = alloca() : memref<8xvector<8xf32>>
    %3 = alloca() : memref<8xvector<8xf32>>
    %4 = alloc(%c65536) {alignment = 4096 : i64} : memref<?xi8>
    %5 = alloc(%c32768) {alignment = 4096 : i64} : memref<?xi8>
    %6 = std.view %4[%c0][] : memref<?xi8> to memref<256x64xf32>
    %7 = subview %6[0, 0] [256, 64] [1, 1] : memref<256x64xf32> to memref<256x64xf32, #map0>
    %8 = std.view %5[%c0][] : memref<?xi8> to memref<64x128xf32>
    %9 = subview %8[0, 0] [64, 128] [1, 1] : memref<64x128xf32> to memref<64x128xf32, #map1>
    %10 = subview %7[0, 0] [256, 64] [1, 1] : memref<256x64xf32, #map0> to memref<256x64xf32, #map0>
    %11 = subview %9[0, 0] [64, 128] [1, 1] : memref<64x128xf32, #map1> to memref<64x128xf32, #map1>
    %12 = alloc(%c2048) {alignment = 128 : i64} : memref<?xi8>
    %13 = alloc(%c2048) {alignment = 128 : i64} : memref<?xi8>
    %14 = alloc(%c256) {alignment = 128 : i64} : memref<?xi8>
    scf.for %arg3 = %c0 to %c2048 step %c256 {
      scf.for %arg4 = %c0 to %c2048 step %c128 {
        %15 = subview %arg2[%arg3, %arg4] [256, 128] [1, 1] : memref<2048x2048xf32> to memref<256x128xf32, #map2>
        scf.for %arg5 = %c0 to %c2048 step %c64 {
          %16 = subview %arg0[%arg3, %arg5] [256, 64] [1, 1] : memref<2048x2048xf32> to memref<256x64xf32, #map2>
          %17 = subview %arg1[%arg5, %arg4] [64, 128] [1, 1] : memref<2048x2048xf32> to memref<64x128xf32, #map2>
          %18 = subview %16[0, 0] [256, 64] [1, 1] : memref<256x64xf32, #map2> to memref<256x64xf32, #map2>
          linalg.copy(%18, %10) : memref<256x64xf32, #map2>, memref<256x64xf32, #map0>
          %19 = subview %17[0, 0] [64, 128] [1, 1] : memref<64x128xf32, #map2> to memref<64x128xf32, #map2>
          linalg.copy(%19, %11) : memref<64x128xf32, #map2>, memref<64x128xf32, #map1>
          scf.for %arg6 = %c0 to %c256 step %c8 {
            %20 = subview %7[%arg6, 0] [8, 64] [1, 1] : memref<256x64xf32, #map0> to memref<8x64xf32, #map3>
            scf.for %arg7 = %c0 to %c128 step %c8 {
              %21 = subview %15[%arg6, %arg7] [8, 8] [1, 1] : memref<256x128xf32, #map2> to memref<8x8xf32, #map2>
              %22 = subview %9[0, %arg7] [64, 8] [1, 1] : memref<64x128xf32, #map1> to memref<64x8xf32, #map4>
              affine.for %arg8 = 0 to 8 {
                %223 = vector.transfer_read %20[%arg8, %c0], %cst {masked = [false]} : memref<8x64xf32, #map3>, vector<64xf32>
                store %223, %0[%arg8] : memref<8xvector<64xf32>>
              }
              %23 = vector.type_cast %0 : memref<8xvector<64xf32>> to memref<vector<8x64xf32>>
              %24 = load %23[] : memref<vector<8x64xf32>>
              affine.for %arg8 = 0 to 64 {
                %223 = vector.transfer_read %22[%arg8, %c0], %cst {masked = [false]} : memref<64x8xf32, #map4>, vector<8xf32>
                store %223, %1[%arg8] : memref<64xvector<8xf32>>
              }
              %25 = vector.type_cast %1 : memref<64xvector<8xf32>> to memref<vector<64x8xf32>>
              %26 = load %25[] : memref<vector<64x8xf32>>
              affine.for %arg8 = 0 to 8 {
                %223 = vector.transfer_read %21[%arg8, %c0], %cst {masked = [false]} : memref<8x8xf32, #map2>, vector<8xf32>
                store %223, %2[%arg8] : memref<8xvector<8xf32>>
              }
              %27 = vector.type_cast %2 : memref<8xvector<8xf32>> to memref<vector<8x8xf32>>
              %28 = load %27[] : memref<vector<8x8xf32>>
              %29 = vector.transpose %24, [1, 0] : vector<8x64xf32> to vector<64x8xf32>
              // ****************************
              // a series of outer products
              // ****************************
              store %221, %222[] : memref<vector<8x8xf32>>
              affine.for %arg8 = 0 to 8 {
                %223 = load %3[%arg8] : memref<8xvector<8xf32>>
                vector.transfer_write %223, %21[%arg8, %c0] {masked = [false]} : vector<8xf32>, memref<8x8xf32, #map2>
              }
            }
          }
        }
      }
    }
    dealloc %14 : memref<?xi8>
    dealloc %13 : memref<?xi8>
    dealloc %12 : memref<?xi8>
    dealloc %5 : memref<?xi8>
    dealloc %4 : memref<?xi8>
    return
  }
}

My question is: in this MLIR code, where is the packing of A and B happening? There is a linalg.copy operation here:

linalg.copy(%19, %11) : memref<64x128xf32, #map2>, memref<64x128xf32, #map1>

And the affine.for loop seems to load contiguous tiles into registers:

affine.for %arg8 = 0 to 64 {
    %223 = vector.transfer_read %22[%arg8, %c0], %cst {masked = [false]} : memref<64x8xf32, #map4>, vector<8xf32>
    store %223, %1[%arg8] : memref<64xvector<8xf32>>
}

But I am not sure which operation actually packs the 256x64 layout into the (256/8)x(64*8) panel layout. Sorry in advance if I am asking naive questions :slight_smile:
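
To spell out what I mean by that layout, here is a small NumPy illustration (mine, not anything MLIR produces):

import numpy as np

# Promoted blocks with the sizes from the IR above.
A_blk = np.arange(256 * 64, dtype=np.float32).reshape(256, 64)    # 256x64 A block
B_blk = np.arange(64 * 128, dtype=np.float32).reshape(64, 128)    # 64x128 B block

# (256/8) x (64*8): each 8x64 row panel of A flattened and stored back to back.
# For a row-major contiguous A block this is only a reinterpretation -- the
# bytes are already in panel order.
A_panels = A_blk.reshape(256 // 8, 8 * 64)

# The analogous packing of B into (128/8) column panels of 64x8 does reorder
# memory: the elements of each 64x8 panel become contiguous.
B_panels = np.ascontiguousarray(
    B_blk.reshape(64, 128 // 8, 8).transpose(1, 0, 2)).reshape(128 // 8, 64 * 8)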

Thanks,
Giuseppe

For this part, you will need the Linalg-on-tensors transformations (padding and hoist padding) plus comprehensive bufferization to be hooked up properly.
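
As a very rough analogy (NumPy, mine, not the actual transformation): padding turns every tile into a full mr x kc tile, and hoisting those pads above the loops materializes a buffer of packed tiles, which is the packing effect this path is meant to give you once bufferization is hooked up:

import numpy as np

# Conceptual sketch only, not the MLIR transformation itself.
def packed_panels(A, mr, kc):
    M, K = A.shape
    n_row, n_red = -(-M // mr), -(-K // kc)              # ceil-div: number of tiles
    packed = np.zeros((n_row, n_red, mr, kc), A.dtype)   # hoisted, padded storage
    for i in range(n_row):
        for p in range(n_red):
            tile = A[i*mr:(i+1)*mr, p*kc:(p+1)*kc]
            packed[i, p, :tile.shape[0], :tile.shape[1]] = tile   # zero-pad partial tiles
    return packed   # the inner loops then index packed[i, p] instead of re-slicing A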

I’d recommend taking a look at this test: iree-llvm-sandbox/linalg_matmul.py at main · google/iree-llvm-sandbox · GitHub

It uses experts.py in the same test directory.

These are used to drive the iree-llvm-sandbox/LinalgTensorCodegenDriver.cpp at main · google/iree-llvm-sandbox · GitHub

This is where some of us are experimenting out of tree with all batteries included.