MLIR for Arm SME: Reducing tile data transfers

Hi, I have been playing around with SME lowering, and it occurred to me that we could probably significantly improve performance by reusing allocated tiles.

Please consider the payload in llvm-project/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir (main branch of llvm/llvm-project on GitHub).

If we simply run the transform sequence, we end up with the following output.

Post transform sequence output
#map = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
module {
  func.func @matmul(%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>, %arg1: memref<?x?xf32, strided<[?, ?], offset: ?>>, %arg2: memref<?x?xf32, strided<[?, ?], offset: ?>>) {
    %cst = arith.constant 0.000000e+00 : f32
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = memref.dim %arg0, %c0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
    %dim_0 = memref.dim %arg0, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
    %dim_1 = memref.dim %arg1, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
    %vscale = vector.vscale
    %c4_vscale = arith.muli %vscale, %c4 : index
    scf.for %arg3 = %c0 to %dim step %c4_vscale {
      scf.for %arg4 = %c0 to %dim_1 step %c4_vscale {
        scf.for %arg5 = %c0 to %dim_0 step %c1 {
          %0 = affine.min #map(%arg3)[%c4_vscale, %dim]
          %1 = affine.min #map(%arg4)[%c4_vscale, %dim_1]
          %subview = memref.subview %arg0[%arg3, %arg5] [%0, 1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x1xf32, strided<[?, ?], offset: ?>>
          %subview_2 = memref.subview %arg1[%arg5, %arg4] [1, %1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
          %subview_3 = memref.subview %arg2[%arg3, %arg4] [%0, %1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
          %2 = vector.create_mask %0 : vector<[4]xi1>
          %subview_4 = memref.subview %subview[0, 0] [%0, 1] [1, 1] : memref<?x1xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
          %3 = vector.transfer_read %subview_4[%c0], %cst, %2 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
          %4 = vector.shape_cast %3 : vector<[4]xf32> to vector<[4]x1xf32>
          %5 = vector.create_mask %1 : vector<[4]xi1>
          %subview_5 = memref.subview %subview_2[0, 0] [1, %1] [1, 1] : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
          %6 = vector.transfer_read %subview_5[%c0], %cst, %5 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
          %7 = vector.create_mask %0, %1 : vector<[4]x[4]xi1>
          %8 = vector.transfer_read %subview_3[%c0, %c0], %cst, %7 {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
          %9 = vector.transpose %4, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
          %10 = vector.extract %9[0] : vector<[4]xf32> from vector<1x[4]xf32>
          %11 = vector.mask %7 { vector.outerproduct %10, %6, %8 {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
          vector.transfer_write %11, %subview_3[%c0, %c0], %7 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
          memref.copy %subview_3, %subview_3 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
        }
      }
    }
    %cast = memref.cast %arg2 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<*xf32>
    call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
    return
  }
  module attributes {transform.with_named_sequence} {
  }
  func.func private @printMemrefF32(memref<*xf32>)
}

We have tiled our matmul by [[4], [4], 1]. While iterating over the last dimension (i.e. throughout the execution of the innermost loop), the outer product keeps targeting the same region of the result. Since the outer product writes into an accumulator, we should be able to load the tile once before the innermost loop, accumulate into it with the outer product, and store it back only once when leaving the innermost loop.

e.g.

scf.for %arg3 = %c0 to %dim step %c4_vscale {
  scf.for %arg4 = %c0 to %dim_1 step %c4_vscale {
    load ZaTile
    scf.for %arg5 = %c0 to %dim_0 step %c1 {
      load A
      load B
      MOPA A, B, ZaTile
    }
    store ZaTile
  }
}
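
To make the target shape concrete, here is a hand-written sketch (not compiler output) of the same idea at the vector/scf level, carrying the accumulator through iter_args; %subview, %mask2d, %a and %b stand in for the result subview, the combined mask and the per-iteration operands from the surrounding IR:

scf.for %i = %c0 to %dim step %c4_vscale {
  scf.for %j = %c0 to %dim_1 step %c4_vscale {
    // Load the accumulator tile once per (i, j) tile.
    %init = vector.transfer_read %subview[%c0, %c0], %cst, %mask2d {in_bounds = [true, true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
    %res = scf.for %k = %c0 to %dim_0 step %c1
        iter_args(%acc = %init) -> (vector<[4]x[4]xf32>) {
      // ... read the A column (%a) and B row (%b) as before ...
      %next = vector.mask %mask2d { vector.outerproduct %a, %b, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
      scf.yield %next : vector<[4]x[4]xf32>
    }
    // Store the accumulator tile once when leaving the reduction loop.
    vector.transfer_write %res, %subview[%c0, %c0], %mask2d {in_bounds = [true, true]}
      : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
  }
}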

Updating step 5 to run LICM and canonicalization makes it possible to hoist the mask creation ops and the result's subview out of the innermost loop.

Updated transform sequence with LICM call
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
      transform.apply_patterns.vector.lower_masks
      transform.apply_patterns.vector.rank_reducing_subview_patterns
      // Needs canonicalization to make the subview loop-independent
      transform.apply_patterns.canonicalization
    } : !transform.any_op

    %for = transform.structured.match ops{["scf.for"]} in %bufferize
      : (!transform.any_op) -> !transform.any_op
    transform.apply_licm to %for : !transform.any_op

Output:

scf.for %arg3 = %c0 to %dim step %c4_vscale {
   // hoisted through LICM
   %0 = affine.min #map(%arg3)[%c4_vscale, %dim]
   %1 = vector.create_mask %0 : vector<[4]xi1>
   scf.for %arg4 = %c0 to %dim_1 step %c4_vscale {
     // hoisted through LICM + canonicalize
     %2 = affine.min #map(%arg4)[%c4_vscale, %dim_1]
     %subview = memref.subview %arg2[%arg3, %arg4] [%0, %2] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
     %3 = vector.create_mask %2 : vector<[4]xi1>
     %4 = vector.create_mask %0, %2 : vector<[4]x[4]xi1>
     scf.for %arg5 = %c0 to %dim_0 step %c1 {
       %subview_2 = memref.subview %arg0[%arg3, %arg5] [%0, 1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x1xf32, strided<[?, ?], offset: ?>>
       %subview_3 = memref.subview %arg1[%arg5, %arg4] [1, %2] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
       %subview_4 = memref.subview %subview_2[0, 0] [%0, 1] [1, 1] : memref<?x1xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
       %5 = vector.transfer_read %subview_4[%c0], %cst, %1 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
       %6 = vector.shape_cast %5 : vector<[4]xf32> to vector<[4]x1xf32>
       %subview_5 = memref.subview %subview_3[0, 0] [1, %2] [1, 1] : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
       %7 = vector.transfer_read %subview_5[%c0], %cst, %3 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
       // This should turn into a tile load
       %8 = vector.transfer_read %subview[%c0, %c0], %cst, %4 {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
       %9 = vector.transpose %6, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
       %10 = vector.extract %9[0] : vector<[4]xf32> from vector<1x[4]xf32>
       %11 = vector.mask %4 { vector.outerproduct %10, %7, %8 {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
       // This should turn into a tile store
       vector.transfer_write %11, %subview[%c0, %c0], %4 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
       // Suspicious copy?
       memref.copy %subview, %subview : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
     }

Now, the masked outer product does not yet write in place, nor do we have the required SSA form to represent this (the iter_args-carried accumulator sketched above). So let us run lower-to-arm-sme. Here is a snippet of the output: ^bb21 holds the tile store, directly preceded by the loop (^bb20) implementing the row-by-row horizontal store; before that, ^bb19 holds the MOPA itself, preceded by ^bb18, which loads the ZA tile. Conceptually they all belong to the same basic block (if we ignore the loops implementing the row-by-row tile load/store).

Post lower to ArmSME output.
^bb17(%33: index):  // 2 preds: ^bb16, ^bb18
  %34 = builtin.unrealized_conversion_cast %33 : index to i64
  %35 = arith.cmpi slt, %33, %32 : index
  cf.cond_br %35, ^bb18, ^bb19
^bb18:  // pred: ^bb17
  %36 = llvm.extractvalue %8[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %37 = llvm.extractvalue %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %38 = llvm.getelementptr %36[%37] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %39 = llvm.extractvalue %8[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %40 = llvm.mul %34, %39 : i64
  %41 = llvm.extractvalue %8[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %42 = llvm.mul %0, %41 : i64
  %43 = llvm.add %40, %42 : i64
  %44 = llvm.getelementptr %38[%43] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %45 = arith.index_castui %33 : index to i32
  "arm_sme.intr.ld1w.horiz"(%9, %44, %45) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
  %46 = arith.addi %33, %c1 : index
  cf.br ^bb17(%46 : index)
^bb19:  // pred: ^bb17
  %47 = vector.transpose %20, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
  %48 = vector.extract %47[0] : vector<[4]xf32> from vector<1x[4]xf32>
  "arm_sme.intr.mopa"(%4, %9, %48, %22) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
  cf.br ^bb20(%c0 : index)
^bb20(%49: index):  // 2 preds: ^bb19, ^bb21
  %50 = builtin.unrealized_conversion_cast %49 : index to i64
  %51 = arith.cmpi slt, %49, %32 : index
  cf.cond_br %51, ^bb21, ^bb22
^bb21:  // pred: ^bb20
  %52 = llvm.extractvalue %8[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %53 = llvm.extractvalue %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %54 = llvm.getelementptr %52[%53] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %55 = llvm.extractvalue %8[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %56 = llvm.mul %50, %55 : i64
  %57 = llvm.extractvalue %8[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %58 = llvm.mul %0, %57 : i64
  %59 = llvm.add %56, %58 : i64
  %60 = llvm.getelementptr %54[%59] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %61 = arith.index_castui %49 : index to i32
  "arm_sme.intr.st1w.horiz"(%9, %60, %61) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
  %62 = arith.addi %49, %c1 : index
  cf.br ^bb20(%62 : index)

I was wondering if this is a problem you have already sorted out somehow. Do you have suggestions? I think the hoisting should happen at some point during lower-to-arm-sme, while we still have full tile load/store ops. Admittedly it would probably have helped to keep the IR under scf, but lowering to cf has its reasons, cf. [PSA] ArmSME lowering pipeline and tile allocation changes.

Additionally, there is a suspicious memref.copy at the end of the first step. I managed to get rid of it with -buffer-deallocation-pipeline, but I am not exactly sure how it affects performance or memory accesses.

Side note: depending on the order of execution, -buffer-deallocation-pipeline sometimes triggers error: 'vector.mask' op expects only one operation to mask. I'll look into turning the relevant pattern into a MaskableOpPattern.

Final question, probably for @c-rhodes: I am unsure about the use of rank_reducing_subview_patterns. It is not run in the trA example. Does it help some pattern matching in the matmul case, or is it simply a memref optimisation?

CC @banach-space, @MacDue.

This exact lowering is what you get if you use IREE, which already does this hoisting; it is not included in the -test-lower-to-arm-sme pipeline, as that pipeline was mainly intended for functional correctness tests (not optimal code).

Edit: In IREE it’s the --iree-codegen-optimize-tensor-insert-extract-slices pass (which runs right after vectorization) that does the hoisting of the reads/writes. I think everything that pass does is available upstream (just maybe not packaged into a single pass).
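
For example (untested sketch, and I have not checked that the upstream hoisting handles the masked transfers in this exact example), the transform op transform.structured.hoist_redundant_vector_transfers could be appended to the transform sequence; %func here stands for the handle to the vectorized function from earlier in the sequence:

    // Untested: hoist the transfer_read/transfer_write pair on the
    // accumulator out of the reduction loop.
    %hoisted = transform.structured.hoist_redundant_vector_transfers %func
      : (!transform.any_op) -> !transform.any_op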


Also, I’d note that this hoisting is not specific to SME. You’d want this to happen for any code with an accumulator.
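
As a plain (non-SME) illustration of the same idea, this is the usual load/store-to-iter_args rewrite on a scalar accumulator (hand-written sketch; %A, %B, %C, %i, %j are placeholders):

// Before: the accumulator round-trips through memory on every iteration.
scf.for %k = %c0 to %dim step %c1 {
  %acc = memref.load %C[%i, %j] : memref<?x?xf32>
  %a = memref.load %A[%i, %k] : memref<?x?xf32>
  %b = memref.load %B[%k, %j] : memref<?x?xf32>
  %prod = arith.mulf %a, %b : f32
  %sum = arith.addf %acc, %prod : f32
  memref.store %sum, %C[%i, %j] : memref<?x?xf32>
}

// After: the accumulator is loop-carried; one load before, one store after.
%init = memref.load %C[%i, %j] : memref<?x?xf32>
%res = scf.for %k = %c0 to %dim step %c1 iter_args(%acc = %init) -> (f32) {
  %a = memref.load %A[%i, %k] : memref<?x?xf32>
  %b = memref.load %B[%k, %j] : memref<?x?xf32>
  %prod = arith.mulf %a, %b : f32
  %sum = arith.addf %acc, %prod : f32
  scf.yield %sum : f32
}
memref.store %res, %C[%i, %j] : memref<?x?xf32>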