I am trying to improve the performance of SGEMM using MLIR dialects, and I am now facing a bottleneck.
In my implementation I use a 12x8 inner kernel (each step takes 12 values from A and 8 values from B). The previous implementation was written as the following code:
%26 = scf.for %arg8 = %c0 to %arg5 step %c8 iter_args(%arg9 = %84) -> (vector<12x8xf32>) {
  %28 = addi %47, %arg8 : index
  %30 = addi %49, %arg8 : index
  %35 = memref.subview %arg1[%30, 0] [8, 12] [1, 1] : memref<8832x12xf32, #map4> to memref<8x12xf32, #map5>
  %36 = memref.subview %arg2[%28, 0] [8, 8] [1, 1] : memref<141312x8xf32, #map1> to memref<8x8xf32, #map2>
  %53 = vector.transfer_read %36[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %54 = vector.transfer_read %36[%c1, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %55 = vector.transfer_read %36[%c2, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %56 = vector.transfer_read %36[%c3, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %57 = vector.transfer_read %36[%c4, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %58 = vector.transfer_read %36[%c5, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %59 = vector.transfer_read %36[%c6, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %60 = vector.transfer_read %36[%c7, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %61 = vector.transfer_read %35[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %62 = vector.transfer_read %35[%c1, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %63 = vector.transfer_read %35[%c2, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %64 = vector.transfer_read %35[%c3, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %65 = vector.transfer_read %35[%c4, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %66 = vector.transfer_read %35[%c5, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %67 = vector.transfer_read %35[%c6, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %68 = vector.transfer_read %35[%c7, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %69 = vector.outerproduct %61, %53, %arg9 : vector<12xf32>, vector<8xf32>
  %70 = vector.outerproduct %62, %54, %69 : vector<12xf32>, vector<8xf32>
  %71 = vector.outerproduct %63, %55, %70 : vector<12xf32>, vector<8xf32>
  %72 = vector.outerproduct %64, %56, %71 : vector<12xf32>, vector<8xf32>
  %73 = vector.outerproduct %65, %57, %72 : vector<12xf32>, vector<8xf32>
  %74 = vector.outerproduct %66, %58, %73 : vector<12xf32>, vector<8xf32>
  %75 = vector.outerproduct %67, %59, %74 : vector<12xf32>, vector<8xf32>
  %76 = vector.outerproduct %68, %60, %75 : vector<12xf32>, vector<8xf32>
  scf.yield %76 : vector<12x8xf32>
}
The snippet above repeatedly accumulates a 12x8 block of the output matrix C. The problem with this code is that the load instructions take too much time, and what I want is to use the computation to hide the latency of those loads.
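For clarity, here is a rough C sketch of what this micro-kernel computes (my own illustration, not generated from the MLIR; names such as Apack, Bpack, and kc are placeholders for the packed panels behind %35/%36 and the loop bound %arg5):

/* Rough C sketch of the 12x8 micro-kernel above (illustration only).
 * Apack holds 12 floats of A per k-step, Bpack holds 8 floats of B per
 * k-step, and acc is the 12x8 accumulator carried through iter_args. */
void micro_kernel_12x8(const float *Apack, const float *Bpack,
                       float acc[12][8], int kc)
{
    for (int k = 0; k < kc; ++k)            /* the k-loop (scf.for above)      */
        for (int i = 0; i < 12; ++i)        /* one vector.outerproduct is one  */
            for (int j = 0; j < 8; ++j)     /* rank-1 update of the 12x8 block */
                acc[i][j] += Apack[k * 12 + i] * Bpack[k * 8 + j];
}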
Then I found a possible solution in MLIR: loop pipelining. Here is an example input:
func @long_liverange(%A: memref<?xf32>, %result: memref<?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c10 = constant 10 : index
  %cf = constant 1.0 : f32
  scf.for %i0 = %c0 to %c10 step %c1 {
    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
    %A1_elem = addf %A_elem, %cf { __test_pipelining_stage__ = 4, __test_pipelining_op_order__ = 0 } : f32
    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 4, __test_pipelining_op_order__ = 1 } : memref<?xf32>
  } { __test_pipelining_loop__ }
  return
}
The transformation rewrites the loop into the following form, using the add instructions to hide the latency of the loads and stores:
module {
  func @long_liverange(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %cst = constant 1.000000e+00 : f32
    %c2 = constant 2 : index
    %c3 = constant 3 : index
    %c6 = constant 6 : index
    %c4 = constant 4 : index
    %c9 = constant 9 : index
    %c8 = constant 8 : index
    %c7 = constant 7 : index
    %0 = memref.load %arg0[%c0] : memref<?xf32>
    %1 = memref.load %arg0[%c1] : memref<?xf32>
    %2 = memref.load %arg0[%c2] : memref<?xf32>
    %3 = memref.load %arg0[%c3] : memref<?xf32>
    %4:4 = scf.for %arg2 = %c0 to %c6 step %c1 iter_args(%arg3 = %0, %arg4 = %1, %arg5 = %2, %arg6 = %3) -> (f32, f32, f32, f32) {
      %9 = addf %arg3, %cst : f32
      memref.store %9, %arg1[%arg2] : memref<?xf32>
      %10 = addi %arg2, %c4 : index
      %11 = memref.load %arg0[%10] : memref<?xf32>
      scf.yield %arg4, %arg5, %arg6, %11 : f32, f32, f32, f32
    }
    %5 = addf %4#0, %cst : f32
    memref.store %5, %arg1[%c6] : memref<?xf32>
    %6 = addf %4#1, %cst : f32
    memref.store %6, %arg1[%c7] : memref<?xf32>
    %7 = addf %4#2, %cst : f32
    memref.store %7, %arg1[%c8] : memref<?xf32>
    %8 = addf %4#3, %cst : f32
    memref.store %8, %arg1[%c9] : memref<?xf32>
    return
  }
}
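To make the shape of the transformation explicit, here is a rough C equivalent of the pipelined output (my own sketch, not produced by MLIR): four loads are kept in flight, the add of iteration i overlaps with the load for iteration i+4, and the last four adds/stores are peeled into an epilogue.

/* Rough C equivalent of the pipelined loop above (illustration only). */
void long_liverange(const float *A, float *result)
{
    /* prologue: issue the first four loads (stage distance 4 - 0 = 4) */
    float in0 = A[0], in1 = A[1], in2 = A[2], in3 = A[3];

    /* steady-state kernel: 10 - 4 = 6 iterations */
    for (int i = 0; i < 6; ++i) {
        result[i] = in0 + 1.0f;        /* use the value loaded 4 iterations ago */
        float next = A[i + 4];         /* load for iteration i + 4              */
        in0 = in1; in1 = in2; in2 = in3; in3 = next;   /* rotate the iter_args  */
    }

    /* epilogue: drain the four values still in flight */
    result[6] = in0 + 1.0f;
    result[7] = in1 + 1.0f;
    result[8] = in2 + 1.0f;
    result[9] = in3 + 1.0f;
}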
So I applied the method above and changed my code into the following snippet:
%26 = scf.for %arg8 = %c0 to %arg5 step %c8 iter_args(%arg9 = %84) -> (vector<12x8xf32>) {
  %28 = addi %47, %arg8 : index
  %30 = addi %49, %arg8 : index
  %35 = memref.subview %arg1[%30, 0] [8, 12] [1, 1] : memref<8832x12xf32, #map4> to memref<8x12xf32, #map5>
  %36 = memref.subview %arg2[%28, 0] [8, 8] [1, 1] : memref<141312x8xf32, #map1> to memref<8x8xf32, #map2>
  %53 = vector.transfer_read %36[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %54 = vector.transfer_read %35[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %55:3 = scf.for %arg10 = %c1 to %c7 step %c2 iter_args(%arg11 = %53, %arg12 = %54, %arg13 = %arg9) -> (vector<8xf32>, vector<12xf32>, vector<12x8xf32>) {
    %56 = vector.outerproduct %arg12, %arg11, %arg13 : vector<12xf32>, vector<8xf32>
    %57 = vector.transfer_read %36[%arg10, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
    %58 = vector.transfer_read %35[%arg10, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
    %59 = vector.outerproduct %58, %57, %56 : vector<12xf32>, vector<8xf32>
    %60 = addi %arg10, %c1 : index
    %61 = vector.transfer_read %36[%60, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
    %62 = vector.transfer_read %35[%60, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
    scf.yield %61, %62, %59 : vector<8xf32>, vector<12xf32>, vector<12x8xf32>
  }
  %63 = vector.outerproduct %55#1, %55#0, %55#2 : vector<12xf32>, vector<8xf32>
  %64 = vector.transfer_read %36[%c7, %c0], %cst_0 {in_bounds = [true]} : memref<8x8xf32, #map2>, vector<8xf32>
  %65 = vector.transfer_read %35[%c7, %c0], %cst_0 {in_bounds = [true]} : memref<8x12xf32, #map5>, vector<12xf32>
  %66 = vector.outerproduct %65, %64, %63 : vector<12xf32>, vector<8xf32>
  scf.yield %66 : vector<12x8xf32>
}
However, the generated assembly is not what I expected: there are many redundant instructions. For example, there should be no store instructions (e.g. stp/str spills) inside the loop, since all stores are supposed to happen after the kernel. The performance is also worse than that of the version shown at the beginning.
.LBB0_7: // in Loop: Header=BB0_8 Depth=3
add x22, x17, x22, lsl #2
add x21, x14, x21, lsl #2
ldp q2, q0, [x22, #224]
ldp q1, q3, [x21, #336]
ldr q4, [x21, #368]
add x7, x7, #8
fmla v15.4s, v0.4s, v1.s[0]
fmla v5.4s, v2.4s, v1.s[0]
fmla v25.4s, v0.4s, v1.s[1]
fmla v30.4s, v2.4s, v1.s[1]
fmla v29.4s, v0.4s, v1.s[2]
fmla v31.4s, v2.4s, v1.s[2]
fmla v27.4s, v0.4s, v1.s[3]
fmla v26.4s, v2.4s, v1.s[3]
ldr q1, [sp, #144] // 16-byte Folded Reload
fmla v23.4s, v0.4s, v3.s[0]
fmla v28.4s, v2.4s, v3.s[0]
fmla v20.4s, v0.4s, v3.s[1]
fmla v1.4s, v0.4s, v4.s[0]
str q1, [sp, #144] // 16-byte Folded Spill
ldr q1, [sp, #128] // 16-byte Folded Reload
fmla v24.4s, v2.4s, v3.s[1]
fmla v21.4s, v0.4s, v3.s[2]
fmla v22.4s, v2.4s, v3.s[2]
fmla v1.4s, v2.4s, v4.s[0]
str q1, [sp, #128] // 16-byte Folded Spill
ldr q1, [sp, #112] // 16-byte Folded Reload
fmla v19.4s, v0.4s, v3.s[3]
fmla v18.4s, v2.4s, v3.s[3]
fmla v6.4s, v0.4s, v4.s[1]
fmla v1.4s, v2.4s, v4.s[1]
str q1, [sp, #112] // 16-byte Folded Spill
ldr q1, [sp, #96] // 16-byte Folded Reload
fmla v12.4s, v0.4s, v4.s[2]
fmla v13.4s, v2.4s, v4.s[2]
fmla v17.4s, v2.4s, v4.s[3]
fmla v1.4s, v0.4s, v4.s[3]
str q1, [sp, #96]
For reference, the code is lowered to LLVM IR and compiled in the following way:
mlir-opt intermediate.mlir -lower-affine -convert-linalg-to-loops -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts > test.mlir
mlir-translate --mlir-to-llvmir test.mlir > test.ll
llc test.ll -O3 -filetype=obj -o test.o
llc test.ll --debugify-quiet -O3 -filetype=asm -o test.s
clang++ test.o -o exec -lmlir_runner_utils -lmlir_c_runner_utils
./exec
Is anyone interested in this topic?