Hi all,
I have been playing for a few days with implementing the full BLAS expression of GEMM in Linalg, but I am hitting a wall and I was hoping for some help. For the sake of clarity, this is the full GEMM expression:

`C = alpha*A*B + beta*C`

where `alpha` and `beta` are scalars.
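As a reference for the intended semantics, here is a plain-Python sketch of the full expression (the function name `gemm_ref` is mine, purely for illustration; it is not part of any MLIR code):

```python
# Reference semantics of full GEMM: returns alpha*A*B + beta*C.
# Plain Python on nested lists, just to pin down the math.
def gemm_ref(alpha, A, B, beta, C):
    M, K, N = len(A), len(A[0]), len(B[0])
    out = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                acc += A[i][k] * B[k][j]  # the matmul reduction
            out[i][j] = alpha * acc + beta * C[i][j]
    return out
```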
I was able to get the correct result by doing the matmul first (`alpha*A*B`) and then adding `beta*C` element-wise. However, this sequence needs a temporary tensor to store the matmul result.
The most efficient way to compute this is to first calculate `beta*C` in place and then run the `A*B` accumulation on `C`, still in place. No matter how many attempts I made, I was not able to generate optimal and correct MLIR for this.
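A Python sketch of that two-pass in-place scheme (the function name is mine, for illustration only):

```python
# Two-pass in-place GEMM: no temporary matrix is needed.
def gemm_inplace(alpha, A, B, beta, C):
    M, K, N = len(A), len(A[0]), len(B[0])
    # Pass 1: C = beta*C, element-wise, in place.
    for i in range(M):
        for j in range(N):
            C[i][j] *= beta
    # Pass 2: C += alpha*A*B, still in place.
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                acc += A[i][k] * B[k][j]
            C[i][j] += alpha * acc
    return C
```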
This is the Linalg code I tried (EDIT: in this code I am pre-transposing matrix `A` and ignoring the multiplication by `alpha`, so what I am really computing is `C = trans(A)*B + beta*C`):
```mlir
// mlir-opt --linalg-comprehensive-module-bufferize --convert-linalg-to-loops %s
!type_A = type tensor<2048x2048xf32>
!type_B = type tensor<2048x2048xf32>
!type_C = type tensor<2048x2048xf32>

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#scalar_map_0 = affine_map<(d0, d1) -> ()>
#scalar_map_1 = affine_map<(d0, d1, d2) -> ()>
#identity_map = affine_map<(d0, d1) -> (d0, d1)>

func @gemm(%A : !type_A {linalg.buffer_layout = #identity_map, linalg.inplaceable = false},
           %B : !type_B {linalg.buffer_layout = #identity_map, linalg.inplaceable = false},
           %C : !type_C {linalg.buffer_layout = #identity_map, linalg.inplaceable = true},
           %alpha : f32, %beta : f32) -> !type_C {
  // %1 = beta * C
  %1 = linalg.generic {
         indexing_maps = [#scalar_map_0, #map0],
         iterator_types = ["parallel", "parallel"]}
         ins(%beta : f32)
         outs(%C : !type_C) {
    ^bb0(%be : f32, %c : f32):
      %out = arith.mulf %be, %c : f32
      linalg.yield %out : f32
  } -> !type_C

  // %2 = alpha*A*B + %1 = alpha*A*B + beta*C
  %2 = linalg.generic {
         indexing_maps = [affine_map<(m, n, k) -> (k, m)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          #scalar_map_1,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B, %alpha : !type_A, !type_B, f32)
         outs(%1 : !type_C) {
    ^bb0(%a : f32, %b : f32, %al : f32, %c : f32):
      %d = arith.mulf %a, %b : f32
      %e = arith.addf %c, %d : f32
      linalg.yield %e : f32
  } -> !type_C

  return %2 : !type_C
}
```
This is not doing what I want. Indeed, this is the result:
```mlir
module {
  func @gemm(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>,
             %arg2: memref<2048x2048xf32>, %arg3: f32, %arg4: f32) {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c2048 = arith.constant 2048 : index
    scf.for %arg5 = %c0 to %c2048 step %c1 {
      scf.for %arg6 = %c0 to %c2048 step %c1 {
        %0 = memref.load %arg2[%arg5, %arg6] : memref<2048x2048xf32>
        %1 = arith.mulf %arg4, %0 : f32
        memref.store %1, %arg2[%arg5, %arg6] : memref<2048x2048xf32>
      }
    }
    scf.for %arg5 = %c0 to %c2048 step %c1 {
      scf.for %arg6 = %c0 to %c2048 step %c1 {
        scf.for %arg7 = %c0 to %c2048 step %c1 {
          %0 = memref.load %arg0[%arg7, %arg5] : memref<2048x2048xf32>
          %1 = memref.load %arg1[%arg7, %arg6] : memref<2048x2048xf32>
          %2 = memref.load %arg2[%arg5, %arg6] : memref<2048x2048xf32>
          %3 = arith.mulf %0, %1 : f32
          %4 = arith.addf %2, %3 : f32
          memref.store %4, %arg2[%arg5, %arg6] : memref<2048x2048xf32>
        }
      }
    }
    return
  }
}
```
The problem is in the innermost loop. What this loop is basically doing is:

```
for k = 0:K {
  C[i,j] += A[k,i]*B[k,j];
}
```

But this is wrong, because we are accumulating `beta*C` K times, so we end up with `C = A*B + K*beta*C`. What I really want is:

```
tmp = 0;
for k = 0:K {
  tmp += A[k,i]*B[k,j];
}
C[i,j] += tmp;
```
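In runnable form, this is a Python sketch of the intended computation (with `A` stored pre-transposed, as in the MLIR snippet, so the inner product reads `A[k][i]*B[k][j]`; the function name is mine):

```python
# Desired accumulation order: beta*C is read once per (i, j),
# outside the reduction loop, so it is never re-accumulated.
def gemm_trans_inplace(A, B, beta, C):
    K, M, N = len(A), len(A[0]), len(B[0])  # A is stored transposed
    for i in range(M):
        for j in range(N):
            tmp = 0.0
            for k in range(K):
                tmp += A[k][i] * B[k][j]
            C[i][j] = beta * C[i][j] + tmp
    return C
```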
One way to achieve this is to create a temporary tensor `%tmp` initialized to zero, use that tensor as the output of the matmul, and then add another `linalg.generic` that computes `%out = %C + %tmp`. For completeness, this is what the code would look like:
```mlir
// same maps as before
func @gemm(%A : !type_A {linalg.buffer_layout = #identity_map, linalg.inplaceable = false},
           %B : !type_B {linalg.buffer_layout = #identity_map, linalg.inplaceable = false},
           %C : !type_C {linalg.buffer_layout = #identity_map, linalg.inplaceable = true},
           %alpha : f32, %beta : f32) -> !type_C {
  // %0 = a zero-filled temporary tensor
  %cst = arith.constant 0.0 : f32
  %init = linalg.init_tensor [2048, 2048] : !type_C
  %0 = linalg.fill(%cst, %init) : f32, !type_C -> !type_C

  // %1 = beta * C
  %1 = linalg.generic {
         indexing_maps = [#scalar_map_0, #map0],
         iterator_types = ["parallel", "parallel"]}
         ins(%beta : f32)
         outs(%C : !type_C) {
    ^bb0(%be : f32, %c : f32):
      %out = arith.mulf %be, %c : f32
      linalg.yield %out : f32
  } -> !type_C

  // %2 = alpha*A*B
  %2 = linalg.generic {
         indexing_maps = [affine_map<(m, n, k) -> (k, m)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          #scalar_map_1,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B, %alpha : !type_A, !type_B, f32)
         outs(%0 : !type_C) {
    ^bb0(%a : f32, %b : f32, %al : f32, %c : f32):
      %d = arith.mulf %a, %b : f32
      %e = arith.addf %c, %d : f32
      linalg.yield %e : f32
  } -> !type_C

  // %3 = %1 + %2 = beta*C + A*B
  %3 = linalg.generic {
         indexing_maps = [#map0, #map0],
         iterator_types = ["parallel", "parallel"]}
         ins(%2 : !type_C)
         outs(%1 : !type_C) {
    ^bb0(%x : f32, %y : f32):
      %out = arith.addf %x, %y : f32
      linalg.yield %out : f32
  } -> !type_C

  return %3 : !type_C
}
```
But now I again have a temporary tensor that I want to get rid of. One way to do this is to fuse the addition into the matmul, and this is where I am stuck: no matter how many attempts I made, I was not able to achieve this fusion.
So, I have two/three questions (and a lot of gratitude for any answer):

- Is it possible to generate the optimal, correct result using only two `linalg.generic` operations?
- From what I understood, @nicolasvasilache mentioned that this is not possible in Linalg. If this is the case, how hard would it be to add support for something like this (if it makes sense to add it at all)?
- If adding support for this is not viable, what is the right way to fuse the third `linalg.generic` into the matmul?