Even though I have code for this, I’ll show my steps using mlir-opt with -test-linalg-codegen-strategy passes. This should be equivalent to what I have in code and I think it’s easier to communicate.
Let’s start with the outermost tiling. The author of the paper does an M=64, K=256 tiling with loop order (i,j,k) → (k,i,j). To achieve this with the linalg codegen strategy I’ll run
mlir-opt -test-linalg-codegen-strategy="tile-sizes=64,0,256 tile-interchange=2,0,1" linalg-matmul.mlir
Note that I use a tile size J=0. This seems to work for not tiling a given loop. J=2048 also seems to work, and there are no differences between the two at this level. However, I don’t know if they are always equivalent. This yields
#map0 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map1 = affine_map<(d0) -> (64, -d0 + 2088)>
module {
  func @matmul(%arg0: memref<2088x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2088x2048xf32>) {
    %c64 = constant 64 : index
    %c256 = constant 256 : index
    %c0 = constant 0 : index
    %c2088 = constant 2088 : index
    %c2048 = constant 2048 : index
    scf.for %arg3 = %c0 to %c2048 step %c256 {
      %0 = memref.subview %arg1[%arg3, 0] [256, 2048] [1, 1] : memref<2048x2048xf32> to memref<256x2048xf32, #map0>
      scf.for %arg4 = %c0 to %c2088 step %c64 {
        %1 = affine.min #map1(%arg4)
        %2 = memref.subview %arg0[%arg4, %arg3] [%1, 256] [1, 1] : memref<2088x2048xf32> to memref<?x256xf32, #map0>
        %3 = affine.min #map1(%arg4)
        %4 = memref.subview %arg2[%arg4, 0] [%3, 2048] [1, 1] : memref<2088x2048xf32> to memref<?x2048xf32, #map0>
        linalg.matmul ins(%2, %0 : memref<?x256xf32, #map0>, memref<256x2048xf32, #map0>) outs(%4 : memref<?x2048xf32, #map0>)
      }
    }
    return
  }
}
First, there seems to be a redundancy here. %1 and %3 are the same, so we only need one of them. I’m not aware of a pass that cleans this up. I thought maybe canonicalize would, but it didn’t. Maybe lowerings further down will remove redundancy, but is it possible that this can affect any kind of fusion that we might want to perform before going to a lower level dialect?
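To make the loop structure above easier to follow, here is a minimal Python sketch of what the tiled IR computes. It is only an illustration of the tiling, not of what mlir-opt does internally. The names `tile_extent` and `tiled_matmul` are mine, and the sketch clips the K tile as well so it can be tried on small shapes (the IR needs no such clipping because 2048 is a multiple of 256).

```python
def tile_extent(i0, tile=64, total=2088):
    """Mirror of affine.min #map1: min(64, -d0 + 2088)."""
    return min(tile, total - i0)


def tiled_matmul(A, B, C, tile_m, tile_k):
    """Accumulate C += A * B with tiles visited in (k, i) order; j is untiled."""
    M, K, N = len(A), len(B), len(B[0])
    for k0 in range(0, K, tile_k):        # outer loop over K tiles (%arg3)
        k1 = min(k0 + tile_k, K)          # assumption: clip K too, for small tests
        for i0 in range(0, M, tile_m):    # inner loop over M tiles (%arg4)
            m = min(tile_m, M - i0)       # boundary tile size, like %1 and %3
            # Body of one linalg.matmul on the current tiles.
            for i in range(i0, i0 + m):
                for j in range(N):
                    acc = C[i][j]
                    for k in range(k0, k1):
                        acc += A[i][k] * B[k][j]
                    C[i][j] = acc
    return C
```

With the sizes from the IR, `tile_extent(arg4)` is 64 for every tile except the last one (starting at row 2048), where only 40 rows remain. Both `%1` and `%3` compute that same value.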
Obviously, there’s nothing here to fuse yet. So I’ll use the promote option since that will introduce linalg.copy ops. I’ll ignore the promote-full-tile-pad for now, as I don’t think it’s relevant to the question I want to ask. Replacing the above pass with
mlir-opt -test-linalg-codegen-strategy="tile-sizes=64,0,256 tile-interchange=2,0,1 promote" linalg-matmul.mlir
yields
#map0 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map1 = affine_map<(d0) -> (64, -d0 + 2088)>
#map2 = affine_map<(d0, d1) -> (d0 * 256 + d1)>
#map3 = affine_map<(d0, d1) -> (d0 * 2048 + d1)>
module {
  func @matmul(%arg0: memref<2088x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2088x2048xf32>) {
    %c2088 = constant 2088 : index
    %c65536 = constant 65536 : index
    %c256 = constant 256 : index
    %c2097152 = constant 2097152 : index
    %c64 = constant 64 : index
    %c2048 = constant 2048 : index
    %c0 = constant 0 : index
    %c524288 = constant 524288 : index
    scf.for %arg3 = %c0 to %c2048 step %c256 {
      %0 = memref.subview %arg1[%arg3, 0] [256, 2048] [1, 1] : memref<2048x2048xf32> to memref<256x2048xf32, #map0>
      scf.for %arg4 = %c0 to %c2088 step %c64 {
        %1 = affine.min #map1(%arg4)
        %2 = memref.subview %arg0[%arg4, %arg3] [%1, 256] [1, 1] : memref<2088x2048xf32> to memref<?x256xf32, #map0>
        %3 = affine.min #map1(%arg4)
        %4 = memref.subview %arg2[%arg4, 0] [%3, 2048] [1, 1] : memref<2088x2048xf32> to memref<?x2048xf32, #map0>
        %5 = memref.alloc(%c65536) {alignment = 16 : i64} : memref<?xi8>
        %6 = memref.view %5[%c0][] : memref<?xi8> to memref<64x256xf32>
        %7 = memref.subview %6[0, 0] [%1, 256] [1, 1] : memref<64x256xf32> to memref<?x256xf32, #map2>
        %8 = memref.alloc(%c2097152) {alignment = 16 : i64} : memref<?xi8>
        %9 = memref.view %8[%c0][] : memref<?xi8> to memref<256x2048xf32>
        %10 = memref.subview %9[0, 0] [256, 2048] [1, 1] : memref<256x2048xf32> to memref<256x2048xf32, #map3>
        %11 = memref.alloc(%c524288) {alignment = 16 : i64} : memref<?xi8>
        %12 = memref.view %11[%c0][] : memref<?xi8> to memref<64x2048xf32>
        %13 = memref.subview %12[0, 0] [%3, 2048] [1, 1] : memref<64x2048xf32> to memref<?x2048xf32, #map3>
        // copy tile of A
        linalg.copy(%2, %7) : memref<?x256xf32, #map0>, memref<?x256xf32, #map2>
        // copy tile of B
        linalg.copy(%0, %10) : memref<256x2048xf32, #map0>, memref<256x2048xf32, #map3>
        // copy tile of C
        linalg.copy(%4, %13) : memref<?x2048xf32, #map0>, memref<?x2048xf32, #map3>
        linalg.matmul ins(%7, %10 : memref<?x256xf32, #map2>, memref<256x2048xf32, #map3>) outs(%13 : memref<?x2048xf32, #map3>)
        // stores result of matmul into tile of C
        linalg.copy(%13, %4) : memref<?x2048xf32, #map3>, memref<?x2048xf32, #map0>
        memref.dealloc %5 : memref<?xi8>
        memref.dealloc %8 : memref<?xi8>
        memref.dealloc %11 : memref<?xi8>
      }
    }
    return
  }
}
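As a sanity check on the promoted IR above, the byte counts passed to memref.alloc can be reproduced from the tile shapes and the 4-byte size of f32. The following is a small Python sketch of that arithmetic; the helper name `promoted_buffer_bytes` and the variable names are mine and do not come from MLIR.

```python
F32_BYTES = 4  # an f32 element occupies 4 bytes


def promoted_buffer_bytes(rows, cols, elem_bytes=F32_BYTES):
    """Bytes needed for a full (unclipped) rows x cols tile."""
    return rows * cols * elem_bytes


TILE_M, TILE_K, N = 64, 256, 2048

buffer_bytes = {
    "A": promoted_buffer_bytes(TILE_M, TILE_K),  # matches %c65536
    "B": promoted_buffer_bytes(TILE_K, N),       # matches %c2097152
    "C": promoted_buffer_bytes(TILE_M, N),       # matches %c524288
}

# The allocs, copies and deallocs all sit inside both scf.for loops,
# so each of them runs once per (k tile, i tile) pair.
k_tiles = len(range(0, 2048, TILE_K))
i_tiles = len(range(0, 2088, TILE_M))
body_executions = k_tiles * i_tiles
```

Note that the buffers are always sized for a full tile. The clipped sizes (%1 and %3) only show up in the memref.subview ops taken on top of them.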
Let’s look at the loops corresponding to the sequence of linalg ops. I’ll use convert-linalg-to-loops on the tiled IR and present only the portion corresponding to the linalg ops:
// copy tile of A
scf.for %arg5 = %c0 to %1 step %c1 {
  scf.for %arg6 = %c0 to %c256 step %c1 {
    %14 = memref.load %2[%arg5, %arg6] : memref<?x256xf32, #map0>
    memref.store %14, %7[%arg5, %arg6] : memref<?x256xf32, #map2>
  }
}
// copy tile of B
scf.for %arg5 = %c0 to %c256 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    %14 = memref.load %0[%arg5, %arg6] : memref<256x2048xf32, #map0>
    memref.store %14, %10[%arg5, %arg6] : memref<256x2048xf32, #map3>
  }
}
// copy tile of C
scf.for %arg5 = %c0 to %3 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    %14 = memref.load %4[%arg5, %arg6] : memref<?x2048xf32, #map0>
    memref.store %14, %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
  }
}
// matmul of tile of A and tile of B
scf.for %arg5 = %c0 to %1 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    scf.for %arg7 = %c0 to %c256 step %c1 {
      %14 = memref.load %7[%arg5, %arg7] : memref<?x256xf32, #map2>
      %15 = memref.load %10[%arg7, %arg6] : memref<256x2048xf32, #map3>
      %16 = memref.load %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
      %17 = mulf %14, %15 : f32
      %18 = addf %16, %17 : f32
      memref.store %18, %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
    }
  }
}
// store result in tile of C
scf.for %arg5 = %c0 to %3 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    %14 = memref.load %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
    memref.store %14, %4[%arg5, %arg6] : memref<?x2048xf32, #map0>
  }
}
Now I wouldn’t claim that we necessarily SHOULD fuse any of these loops. But it should be possible. My “desired” IR that I presented in my original post looks like it fuses the packing of B into the matmul in some way (albeit maybe after another level of tiling). However, I’m not too concerned about getting that exactly yet.
My question is, how would I go about doing any kind of fusion at this level? From what I’ve seen so far, there are a couple options:
- Lower to affine and use -affine-loop-fusion
  - This won’t work because the pass only works on a sequence of perfectly nested loops.
    - The loops corresponding to the sequence of linalg ops are perfectly nested; however, they are inside two outer loops, which makes the entire loop structure not perfectly nested.
  - Even if this did work, I think there’s a limitation. Say I wanted to fuse the outer loops of the copy of A and copy of C, or the copy of C with matmul. They have the same outer loop(s), but I think that the fact that the same upper bound value is stored in different variables would prevent them from being able to fuse (referring to %1 and %3 both being assigned affine.min #map1(%arg4)).
- Don’t lower linalg ops and try to use one of the linalg fusion passes
  - Most of these seem to only exist in test passes, so maybe they haven’t been fully developed yet
  - test-linalg-greedy-fusion
    - naively applying this to the tiled IR above leaves it unchanged
  - test-linalg-fusion-transform-patterns
  - linalg-fusion-for-tensor-ops
    - same result, and the tests make it seem like this is only supposed to work for specific ops anyway. I even tried converting to generic ops first.
  - test-linalg-tile-and-fuse
    - Although I’m not particularly concerned about another level of tiling yet, given that we would eventually do another tiling anyway, this option would suffice
    - Unfortunately this doesn’t work. It returns a blank IR. I assume it’s because the linalg ops are inside scf.for loops and the pass doesn’t look for such cases?
- There may be more options I haven’t seen, but it seems like none of these will work after a tile pass since the op calls are inside loops
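To be concrete about the kind of fusion I mean, here is a minimal Python sketch of fusing the outer loops of the copy of A and the copy of C. It only illustrates the transformation on plain nested lists and does not use any MLIR pass; the function names are mine. The fusion is legal here because both outer loops run over the same number of rows, which is exactly the fact that %1 and %3 express under two different names.

```python
def copy_separately(a_src, c_src, rows):
    """Two independent loop nests, as produced by convert-linalg-to-loops."""
    a_dst = [[0] * len(a_src[0]) for _ in range(rows)]
    c_dst = [[0] * len(c_src[0]) for _ in range(rows)]
    for i in range(rows):                 # bound plays the role of %1
        for j in range(len(a_src[0])):
            a_dst[i][j] = a_src[i][j]
    for i in range(rows):                 # bound plays the role of %3
        for j in range(len(c_src[0])):
            c_dst[i][j] = c_src[i][j]
    return a_dst, c_dst


def copy_fused(a_src, c_src, rows):
    """Same copies with the two outer loops fused into one."""
    a_dst = [[0] * len(a_src[0]) for _ in range(rows)]
    c_dst = [[0] * len(c_src[0]) for _ in range(rows)]
    for i in range(rows):                 # single shared outer loop
        for j in range(len(a_src[0])):
            a_dst[i][j] = a_src[i][j]
        for j in range(len(c_src[0])):
            c_dst[i][j] = c_src[i][j]
    return a_dst, c_dst
```

Both functions produce the same buffers. The fused version just walks the rows once instead of twice.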
This is already a long post, so I’ll leave it at that for now. But in summary, I’m just currently looking for a way to fuse ops after performing a tiling.
Thanks