Hi, I have been playing around with SME lowering, and it occurred to me that we could likely improve performance significantly by reusing allocated tiles.
Please consider the payload in llvm-project/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir (on main).
If we simply run the transform sequence, we end up with the following output.
Post transform sequence output
#map = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
module {
func.func @matmul(%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>, %arg1: memref<?x?xf32, strided<[?, ?], offset: ?>>, %arg2: memref<?x?xf32, strided<[?, ?], offset: ?>>) {
%cst = arith.constant 0.000000e+00 : f32
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%dim = memref.dim %arg0, %c0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%dim_0 = memref.dim %arg0, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%dim_1 = memref.dim %arg1, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
scf.for %arg3 = %c0 to %dim step %c4_vscale {
scf.for %arg4 = %c0 to %dim_1 step %c4_vscale {
scf.for %arg5 = %c0 to %dim_0 step %c1 {
%0 = affine.min #map(%arg3)[%c4_vscale, %dim]
%1 = affine.min #map(%arg4)[%c4_vscale, %dim_1]
%subview = memref.subview %arg0[%arg3, %arg5] [%0, 1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x1xf32, strided<[?, ?], offset: ?>>
%subview_2 = memref.subview %arg1[%arg5, %arg4] [1, %1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
%subview_3 = memref.subview %arg2[%arg3, %arg4] [%0, %1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = vector.create_mask %0 : vector<[4]xi1>
%subview_4 = memref.subview %subview[0, 0] [%0, 1] [1, 1] : memref<?x1xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
%3 = vector.transfer_read %subview_4[%c0], %cst, %2 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
%4 = vector.shape_cast %3 : vector<[4]xf32> to vector<[4]x1xf32>
%5 = vector.create_mask %1 : vector<[4]xi1>
%subview_5 = memref.subview %subview_2[0, 0] [1, %1] [1, 1] : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
%6 = vector.transfer_read %subview_5[%c0], %cst, %5 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
%7 = vector.create_mask %0, %1 : vector<[4]x[4]xi1>
%8 = vector.transfer_read %subview_3[%c0, %c0], %cst, %7 {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
%9 = vector.transpose %4, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%10 = vector.extract %9[0] : vector<[4]xf32> from vector<1x[4]xf32>
%11 = vector.mask %7 { vector.outerproduct %10, %6, %8 {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
vector.transfer_write %11, %subview_3[%c0, %c0], %7 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
memref.copy %subview_3, %subview_3 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
}
}
}
%cast = memref.cast %arg2 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<*xf32>
call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
return
}
module attributes {transform.with_named_sequence} {
}
func.func private @printMemrefF32(memref<*xf32>)
}
We have tiled our matmul by [[4], [4], 1]. Along the last dimension (i.e., throughout the execution of the innermost loop), the outerproduct targets the same memory region of the result. Since the outerproduct writes into an accumulator, we should be able to load a tile before entering the innermost loop, apply the outerproduct on that tile, and store it only once when exiting the innermost loop.
e.g.
scf.for %arg3 = %c0 to %dim step %c4_vscale {
  scf.for %arg4 = %c0 to %dim_1 step %c4_vscale {
    load ZaTile
    scf.for %arg5 = %c0 to %dim_0 step %c1 {
      load A
      load B
      MOPA A, B, ZaTile
    }
    store ZaTile
  }
}
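In MLIR terms, before lowering, this could look roughly like the following hand-written sketch (not actual compiler output; %subview, %mask, and the elided loads are placeholders matching the IR above), where the accumulator is threaded through the innermost loop as an iter_args value so the transfer_read/transfer_write pair moves out of the loop:

%init = vector.transfer_read %subview[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
%acc_final = scf.for %arg5 = %c0 to %dim_0 step %c1 iter_args(%acc = %init) -> (vector<[4]x[4]xf32>) {
  // ... load the A and B slices as before ...
  %next = vector.mask %mask { vector.outerproduct %a, %b, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  scf.yield %next : vector<[4]x[4]xf32>
}
vector.transfer_write %acc_final, %subview[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>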
Updating step 5 to run LICM and canonicalization allows hoisting out the mask creations and the result's subview.
Updated transform sequence with licm call
transform.apply_patterns to %func {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
transform.apply_patterns.vector.lower_masks
transform.apply_patterns.vector.rank_reducing_subview_patterns
// Needs canonicalization to make the subview loop-independent
transform.apply_patterns.canonicalization
} : !transform.any_op
%for = transform.structured.match ops{["scf.for"]} in %bufferize
: (!transform.any_op) -> !transform.any_op
transform.apply_licm to %for : !transform.any_op
Output:
scf.for %arg3 = %c0 to %dim step %c4_vscale {
// hoisted through LICM
%0 = affine.min #map(%arg3)[%c4_vscale, %dim]
%1 = vector.create_mask %0 : vector<[4]xi1>
scf.for %arg4 = %c0 to %dim_1 step %c4_vscale {
// hoisted through LICM + canonicalize
%2 = affine.min #map(%arg4)[%c4_vscale, %dim_1]
%subview = memref.subview %arg2[%arg3, %arg4] [%0, %2] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%3 = vector.create_mask %2 : vector<[4]xi1>
%4 = vector.create_mask %0, %2 : vector<[4]x[4]xi1>
scf.for %arg5 = %c0 to %dim_0 step %c1 {
%subview_2 = memref.subview %arg0[%arg3, %arg5] [%0, 1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x1xf32, strided<[?, ?], offset: ?>>
%subview_3 = memref.subview %arg1[%arg5, %arg4] [1, %2] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
%subview_4 = memref.subview %subview_2[0, 0] [%0, 1] [1, 1] : memref<?x1xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
%5 = vector.transfer_read %subview_4[%c0], %cst, %1 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
%6 = vector.shape_cast %5 : vector<[4]xf32> to vector<[4]x1xf32>
%subview_5 = memref.subview %subview_3[0, 0] [1, %2] [1, 1] : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, #map1>
%7 = vector.transfer_read %subview_5[%c0], %cst, %3 {in_bounds = [true]} : memref<?xf32, #map1>, vector<[4]xf32>
// This should turn into a tile load
%8 = vector.transfer_read %subview[%c0, %c0], %cst, %4 {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
%9 = vector.transpose %6, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%10 = vector.extract %9[0] : vector<[4]xf32> from vector<1x[4]xf32>
%11 = vector.mask %4 { vector.outerproduct %10, %7, %8 {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
// This should turn into a tile store
vector.transfer_write %11, %subview[%c0, %c0], %4 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
// Suspicious copy?
memref.copy %subview, %subview : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
}
Now, the masked outerproduct does not yet write in place, nor do we have the required SSA form to represent this. So let us run lower-to-arm-sme. Here is a snippet of the output: ^bb21 holds the tile store (directly preceded by the loop header ^bb20 representing the horizontal, row-by-row store), which in turn follows ^bb19 holding the MOPA itself, preceded by ^bb18 loading the ZA tile. Conceptually they would all belong to the same basic block (if we disregard the loops needed for the row-by-row tile load/store).
Post lower to ArmSME output.
^bb17(%33: index): // 2 preds: ^bb16, ^bb18
%34 = builtin.unrealized_conversion_cast %33 : index to i64
%35 = arith.cmpi slt, %33, %32 : index
cf.cond_br %35, ^bb18, ^bb19
^bb18: // pred: ^bb17
%36 = llvm.extractvalue %8[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%37 = llvm.extractvalue %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%38 = llvm.getelementptr %36[%37] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%39 = llvm.extractvalue %8[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%40 = llvm.mul %34, %39 : i64
%41 = llvm.extractvalue %8[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%42 = llvm.mul %0, %41 : i64
%43 = llvm.add %40, %42 : i64
%44 = llvm.getelementptr %38[%43] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%45 = arith.index_castui %33 : index to i32
"arm_sme.intr.ld1w.horiz"(%9, %44, %45) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
%46 = arith.addi %33, %c1 : index
cf.br ^bb17(%46 : index)
^bb19: // pred: ^bb17
%47 = vector.transpose %20, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%48 = vector.extract %47[0] : vector<[4]xf32> from vector<1x[4]xf32>
"arm_sme.intr.mopa"(%4, %9, %48, %22) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
cf.br ^bb20(%c0 : index)
^bb20(%49: index): // 2 preds: ^bb19, ^bb21
%50 = builtin.unrealized_conversion_cast %49 : index to i64
%51 = arith.cmpi slt, %49, %32 : index
cf.cond_br %51, ^bb21, ^bb22
^bb21: // pred: ^bb20
%52 = llvm.extractvalue %8[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%53 = llvm.extractvalue %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%54 = llvm.getelementptr %52[%53] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%55 = llvm.extractvalue %8[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%56 = llvm.mul %50, %55 : i64
%57 = llvm.extractvalue %8[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%58 = llvm.mul %0, %57 : i64
%59 = llvm.add %56, %58 : i64
%60 = llvm.getelementptr %54[%59] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%61 = arith.index_castui %49 : index to i32
"arm_sme.intr.st1w.horiz"(%9, %60, %61) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
%62 = arith.addi %49, %c1 : index
cf.br ^bb20(%62 : index)
I was wondering whether this is a problem you have already sorted out somehow. Do you have suggestions? I think the hoisting should happen sometime during lower-to-arm-sme, while we still have full tile load/store ops. It would probably have helped to keep things under scf, although lowering to cf has its reasons, cf.: [PSA] ArmSME lowering pipeline and tile allocation changes
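To make the target shape concrete, here is a hand-written sketch of what I have in mind at the ArmSME level, before tile allocation (op names are the high-level ArmSME dialect ops; the exact syntax and operand names here are an assumption on my part, not verified output):

%tile_init = arm_sme.tile_load %subview[%c0, %c0], %cst, %mask : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>
%tile = scf.for %arg5 = %c0 to %dim_0 step %c1 iter_args(%acc = %tile_init) -> (vector<[4]x[4]xf32>) {
  // ... load the A and B slices as before ...
  %next = arm_sme.outerproduct %a, %b acc(%acc) masks(%ma, %mb) : vector<[4]xf32>, vector<[4]xf32>
  scf.yield %next : vector<[4]x[4]xf32>
}
arm_sme.tile_store %tile, %subview[%c0, %c0], %mask : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<[4]x[4]xf32>

With this form, the per-row ld1w/st1w loops would be emitted only once per (i, j) tile rather than once per k iteration.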
Additionally, there is a suspicious memref.copy at the end of the first step. I managed to get rid of it with -buffer-deallocation-pipeline. I am not exactly sure how it affects performance or memory accesses.
Side note: -buffer-deallocation-pipeline sometimes generates error: 'vector.mask' op expects only one operation to mask, depending on the order of execution. I'll look into turning the right pattern into a MaskableOpPattern.
Final question, probably for @c-rhodes: I am unsure about the use of rank_reducing_subview_patterns. It is not run in the trA example. Does it help some pattern matching in the matmul case, or is it simply a memref optimisation?
CC @banach-space, @MacDue.