Help needed reproducing a specific BLIS-like strategy

I’ve been trying to follow this paper to learn how to use MLIR to generate a BLIS-like strategy for matrix multiplication. However, I’m specifically trying to figure out how to do this with the linalg codegen strategy and I’m having a really hard time reproducing something that would be equivalent to some of the intermediate passes being done.

At the moment, I’m trying to reproduce the IR in section 2.6. I want to start out with the following MLIR file with a simple call to linalg.matmul

func @matmul(%arg0: memref<2088x2048xf64>,
             %arg1: memref<2048x2048xf64>,
             %arg2: memref<2088x2048xf64>) {
  linalg.matmul
    ins(%arg0, %arg1: memref<2088x2048xf64>,
                      memref<2048x2048xf64>)
   outs(%arg2: memref<2088x2048xf64>)
  return
}

and end up with something equivalent to the following

#map4 = affine_map<(d0) -> (256 * d0)>
#map5 = affine_map<(d0) -> (256 * d0 + 256)>
#map6 = affine_map<(d0) -> (64 * d0)>
#map7 = affine_map<(d0) -> (2088, 64 * d0 + 64)>
#map9 = affine_map<(d0) -> (8 * d0)>
#map10 = affine_map<(d0) -> (8 * d0 + 8)>
#map16 = affine_map<(d0) -> (16 * d0)>
#map17 = affine_map<(d0) -> (522, 16 * d0 + 16)>

func @matmul(%A: memref<2088x2048xf64>, %B: memref<2048x2048xf64>, %C: memref<2088x2048xf64>) {
    affine.for %arg3 = 0 to 8 {
      affine.for %arg4 = 0 to 33 {
        %0 = memref.alloc() : memref<64x256xf64>
        // Packing %A into a 64x256 buffer.
        affine.for %arg5 = #map6(%arg4) to min #map7(%arg4) {
          affine.for %arg6 = #map4(%arg3) to #map5(%arg3) {
            %1 = affine.load %A[%arg5, %arg6] : memref<2088x2048xf64>
            affine.store %1, %0[%arg4 * -64 + %arg5, %arg3 * -256 + %arg6] : memref<64x256xf64>
          }
        }
        affine.for %arg5 = 0 to 256 {
          %1 = memref.alloc() : memref<256x8xf64>
          // Packing %B into a 256x8 buffer.
          affine.for %arg6 = #map4(%arg3) to #map5(%arg3) {
            affine.for %arg7 = #map9(%arg5) to #map10(%arg5) {
              %2 = affine.load %B[%arg6, %arg7] : memref<2048x2048xf64>
              affine.store %2, %1[%arg3 * -256 + %arg6, %arg5 * -8 + %arg7] : memref<256x8xf64>
            }
          }
          affine.for %arg6 = #map16(%arg4) to min #map17(%arg4) {
            // This is multiplying a packed 64x256 LHS panel with a packed 256x8 RHS panel.
            affine.for %arg7 = 0 to 256 {
              affine.for %arg8 = 0 to 8 {
                affine.for %arg9 = 0 to 4 {
                  %2 = affine.load %0[%arg4 * -64 + %arg6 * 4 + %arg9, %arg7] : memref<64x256xf64>
                  %3 = affine.load %1[%arg7, %arg8] : memref<256x8xf64>
                  %4 = affine.load %C[%arg6 * 4 + %arg9, %arg5 * 8 + %arg8] : memref<2088x2048xf64>
                  %5 = mulf %2, %3 : f64
                  %6 = addf %4, %5 : f64
                  affine.store %6, %C[%arg6 * 4 + %arg9, %arg5 * 8 + %arg8] : memref<2088x2048xf64>
                }
              }
            }
          }
          memref.dealloc %1 : memref<256x8xf64>
        }
        memref.dealloc %0 : memref<64x256xf64>
      }
    }
    return
  }

I don’t care about vectorization or unrolling yet, nor do I care about the specific dialect that the resulting IR uses. I would just like something that yields the same or comparable performance via passes using the linalg codegen strategy, if possible.
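For readers who want to see the target loop structure outside of MLIR, here is a minimal Python sketch of the same idea: loop order (k-tile, i-tile, j-panel), with a tile of A and a panel of B copied into small local buffers before the multiply. The function names and the tiny tile sizes (MC, KC, NR) are stand-ins chosen for illustration; they are not the 64/256/8 sizes from the paper, and the extra 4-row micro-kernel loop from the IR above is left out.

```python
def naive_matmul(A, B, M, N, K):
    """Reference triple loop: C[i][j] = sum over k of A[i][k] * B[k][j]."""
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += A[i][k] * B[k][j]
    return C


def packed_matmul(A, B, M, N, K, MC=4, KC=8, NR=2):
    """Tiled matmul with packing; MC, KC, NR are illustrative tile sizes."""
    C = [[0.0] * N for _ in range(M)]
    for k0 in range(0, K, KC):                 # outer loop over K tiles
        kc = min(KC, K - k0)
        for i0 in range(0, M, MC):             # loop over M tiles
            mc = min(MC, M - i0)               # partial tile at the boundary
            # Pack an mc x kc tile of A into a contiguous local buffer.
            packed_a = [[A[i0 + i][k0 + k] for k in range(kc)]
                        for i in range(mc)]
            for j0 in range(0, N, NR):         # loop over column panels of B
                nr = min(NR, N - j0)
                # Pack a kc x nr panel of B.
                packed_b = [[B[k0 + k][j0 + j] for j in range(nr)]
                            for k in range(kc)]
                # Multiply the packed tile by the packed panel into C.
                for i in range(mc):
                    for j in range(nr):
                        acc = C[i0 + i][j0 + j]
                        for k in range(kc):
                            acc += packed_a[i][k] * packed_b[k][j]
                        C[i0 + i][j0 + j] = acc
    return C


# Small integer-valued inputs so the comparison is exact.
M, N, K = 10, 6, 20
A = [[float((i * K + k) % 7) for k in range(K)] for i in range(M)]
B = [[float((k * N + j) % 5) for j in range(N)] for k in range(K)]
reference = naive_matmul(A, B, M, N, K)
packed = packed_matmul(A, B, M, N, K)
```

With M = 10 and MC = 4 the last row tile is partial, which plays the same role as the min bound on the 2088-row dimension in the IR above.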

I’ve tried many different things in an attempt to reproduce this, and I get something close-ish by using tiling, interchanges, and promotions similar to slide 133 here. However, the resulting performance is about 1/3 of what the above strategy yields.

Sorry if this is too broad a question. But if anyone has any pointers, that would be greatly appreciated. I can also post what I have tried. Although I’d like to wait and see if someone has a simple suggestion first and then we can go from there. I imagine there is someone out there that already knows how exactly to do this.

Thanks!

I am not clear which strategy you are trying atm, but for such large sizes you’d need packing (== padding + hoisting of padding).

This is available on the linalg on tensors transformation path (see e.g. the step-by-step IR generated by this example: https://github.com/google/iree-llvm-sandbox/blob/main/runners/test/test_matmul_f32_004.mlir).

This will require “comprehensive bufferization” to land before the rest can be flushed to core.

I’m away for the week but I’ll play a little with python-based search when I come back and see where we are.

First let me clarify and say that I’m not necessarily interested in the performance, but rather how to use the tools to perform any transformation I want (whether it’s a good strategy or not). This is meant more for educational purposes than anything else.

With that in mind, I’m not quite sure what relevance the size has. I do think array packing is what I’m least clear about. I understand what it is, and I was able to figure out how to do it by hand. But I cannot figure out how to do it with linalg codegen strategy. I was under the impression that promote somewhat served that purpose. I could be wrong. It’s very hard to find any documentation that clearly states what some of these options do.
I’ll take a look at padding + hoisting (the latter is another term that is not very clear to me and that I can’t find a definition of, but I can vaguely infer what it means).

Thanks, I’ll take a look at this and play around.

I’ll also work on posting a small example showing what I’ve tried so far.

Even though I have code for this, I’ll show my steps using mlir-opt with -test-linalg-codegen-strategy passes. This should be equivalent to what I have in code and I think it’s easier to communicate.

Let’s start with the outermost tiling. The author of the paper does an M=64, K=256 tiling with loop order (i,j,k) → (k,i,j). To achieve this with linalg codegen strategy I’ll run

mlir-opt -test-linalg-codegen-strategy="tile-sizes=64,0,256 tile-interchange=2,0,1" linalg-matmul.mlir

Note that I use a tile size J=0. This seems to work for not tiling a given loop. J=2048 also seems to work, and there are no differences between the two at this level. However, I don’t know if they are always equivalent. This yields

#map0 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map1 = affine_map<(d0) -> (64, -d0 + 2088)>
module  {
  func @matmul(%arg0: memref<2088x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2088x2048xf32>) {
    %c64 = constant 64 : index
    %c256 = constant 256 : index
    %c0 = constant 0 : index
    %c2088 = constant 2088 : index
    %c2048 = constant 2048 : index
    scf.for %arg3 = %c0 to %c2048 step %c256 {
      %0 = memref.subview %arg1[%arg3, 0] [256, 2048] [1, 1] : memref<2048x2048xf32> to memref<256x2048xf32, #map0>
      scf.for %arg4 = %c0 to %c2088 step %c64 {
        %1 = affine.min #map1(%arg4)
        %2 = memref.subview %arg0[%arg4, %arg3] [%1, 256] [1, 1] : memref<2088x2048xf32> to memref<?x256xf32, #map0>
        %3 = affine.min #map1(%arg4)
        %4 = memref.subview %arg2[%arg4, 0] [%3, 2048] [1, 1] : memref<2088x2048xf32> to memref<?x2048xf32, #map0>
        linalg.matmul ins(%2, %0 : memref<?x256xf32, #map0>, memref<256x2048xf32, #map0>) outs(%4 : memref<?x2048xf32, #map0>)
      }
    }
    return
  }
}
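To make the loop bounds in this IR concrete, here is a small Python sketch of what I understand the tiling to compute per dimension. The helper name tile_ranges is made up for illustration. The min mirrors affine.min #map1, i.e. min(64, 2088 - d0), and a tile size of 0 is modeled as "no loop at all", which is an assumption based on the IR above, where no loop over J is emitted.

```python
def tile_ranges(extent, tile):
    """Return (offset, size) pairs for one tiled dimension.

    tile == 0 is modeled as 'do not tile': no loop, the op sees the
    whole dimension at once.
    """
    if tile == 0:
        return [(0, extent)]
    # min(tile, extent - off) is the boundary clamp from affine.min.
    return [(off, min(tile, extent - off)) for off in range(0, extent, tile)]


m_tiles = tile_ranges(2088, 64)    # tiled M dimension
k_tiles = tile_ranges(2048, 256)   # tiled K dimension
n_tiles = tile_ranges(2048, 0)     # untiled N dimension

print(len(m_tiles), m_tiles[-1])   # number of M tiles and the last, partial one
print(len(k_tiles), n_tiles)
```

The M dimension is the only one that needs the clamp here, since 2088 is not a multiple of 64 while 2048 is a multiple of 256.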

First, there seems to be a redundancy here. %1 and %3 are the same, so we only need one of them. I’m not aware of a pass that cleans this up. I thought maybe canonicalize would, but it didn’t. Maybe lowerings further down will remove redundancy, but is it possible that this can affect any kind of fusion that we might want to perform before going to a lower level dialect?

Obviously, there’s nothing here to fuse yet. So I’ll use the promote option since that will introduce linalg.copy ops. I’ll ignore the promote-full-tile-pad for now, as I don’t think it’s relevant to the question I want to ask. Replacing the above pass with

mlir-opt -test-linalg-codegen-strategy="tile-sizes=64,0,256 tile-interchange=2,0,1 promote" linalg-matmul.mlir

yields

#map0 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map1 = affine_map<(d0) -> (64, -d0 + 2088)>
#map2 = affine_map<(d0, d1) -> (d0 * 256 + d1)>
#map3 = affine_map<(d0, d1) -> (d0 * 2048 + d1)>
module  {
  func @matmul(%arg0: memref<2088x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2088x2048xf32>) {
    %c2088 = constant 2088 : index
    %c65536 = constant 65536 : index
    %c256 = constant 256 : index
    %c2097152 = constant 2097152 : index
    %c64 = constant 64 : index
    %c2048 = constant 2048 : index
    %c0 = constant 0 : index
    %c524288 = constant 524288 : index
    scf.for %arg3 = %c0 to %c2048 step %c256 {
      %0 = memref.subview %arg1[%arg3, 0] [256, 2048] [1, 1] : memref<2048x2048xf32> to memref<256x2048xf32, #map0>
      scf.for %arg4 = %c0 to %c2088 step %c64 {
        %1 = affine.min #map1(%arg4)
        %2 = memref.subview %arg0[%arg4, %arg3] [%1, 256] [1, 1] : memref<2088x2048xf32> to memref<?x256xf32, #map0>
        %3 = affine.min #map1(%arg4)
        %4 = memref.subview %arg2[%arg4, 0] [%3, 2048] [1, 1] : memref<2088x2048xf32> to memref<?x2048xf32, #map0>
        %5 = memref.alloc(%c65536) {alignment = 16 : i64} : memref<?xi8>
        %6 = memref.view %5[%c0][] : memref<?xi8> to memref<64x256xf32>
        %7 = memref.subview %6[0, 0] [%1, 256] [1, 1] : memref<64x256xf32> to memref<?x256xf32, #map2>
        %8 = memref.alloc(%c2097152) {alignment = 16 : i64} : memref<?xi8>
        %9 = memref.view %8[%c0][] : memref<?xi8> to memref<256x2048xf32>
        %10 = memref.subview %9[0, 0] [256, 2048] [1, 1] : memref<256x2048xf32> to memref<256x2048xf32, #map3>
        %11 = memref.alloc(%c524288) {alignment = 16 : i64} : memref<?xi8>
        %12 = memref.view %11[%c0][] : memref<?xi8> to memref<64x2048xf32>
        %13 = memref.subview %12[0, 0] [%3, 2048] [1, 1] : memref<64x2048xf32> to memref<?x2048xf32, #map3>
        // copy tile of A
        linalg.copy(%2, %7) : memref<?x256xf32, #map0>, memref<?x256xf32, #map2>
        // copy tile of B 
        linalg.copy(%0, %10) : memref<256x2048xf32, #map0>, memref<256x2048xf32, #map3>
        // copy tile of C 
        linalg.copy(%4, %13) : memref<?x2048xf32, #map0>, memref<?x2048xf32, #map3> 
        linalg.matmul ins(%7, %10 : memref<?x256xf32, #map2>, memref<256x2048xf32, #map3>) outs(%13 : memref<?x2048xf32, #map3>)
        // stores result of matmul into tile of C
        linalg.copy(%13, %4) : memref<?x2048xf32, #map3>, memref<?x2048xf32, #map0> 
        memref.dealloc %5 : memref<?xi8>
        memref.dealloc %8 : memref<?xi8>
        memref.dealloc %11 : memref<?xi8>
      }
    }
    return
  }
}
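As a quick sanity check on the constants in this IR, the three memref.alloc sizes are the full tile shapes in bytes, assuming 4-byte f32 elements as the memref types above indicate. The sketch below is plain arithmetic in Python; the helper names are made up for illustration.

```python
ELEM_BYTES = 4  # f32, matching the element type in the IR above


def promoted_buffer_bytes(rows, cols, elem_bytes=ELEM_BYTES):
    """Byte size of a dense rows x cols promoted buffer."""
    return rows * cols * elem_bytes


def row_major_offset(d0, d1, row_stride):
    """Linear element index for layouts like #map2 / #map3: d0 * stride + d1."""
    return d0 * row_stride + d1


a_bytes = promoted_buffer_bytes(64, 256)     # buffer behind %6 (tile of A)
b_bytes = promoted_buffer_bytes(256, 2048)   # buffer behind %9 (tile of B)
c_bytes = promoted_buffer_bytes(64, 2048)    # buffer behind %12 (tile of C)

print(a_bytes, b_bytes, c_bytes)
```

Note that the buffers are sized for the full 64-row tile even on the last, partial tile; the subviews %7 and %13 then select only the first %1 (or %3) rows.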

Let’s look at the loops corresponding to the sequence of linalg ops. I’ll use convert-linalg-to-loops on the tiled IR and just present the portion corresponding to the linalg ops

// copy tile of A
scf.for %arg5 = %c0 to %1 step %c1 {
  scf.for %arg6 = %c0 to %c256 step %c1 {
    %14 = memref.load %2[%arg5, %arg6] : memref<?x256xf32, #map0>
    memref.store %14, %7[%arg5, %arg6] : memref<?x256xf32, #map2>
  }
}
// copy tile of B
scf.for %arg5 = %c0 to %c256 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    %14 = memref.load %0[%arg5, %arg6] : memref<256x2048xf32, #map0>
    memref.store %14, %10[%arg5, %arg6] : memref<256x2048xf32, #map3>
  }
}
// copy tile of C
scf.for %arg5 = %c0 to %3 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    %14 = memref.load %4[%arg5, %arg6] : memref<?x2048xf32, #map0>
    memref.store %14, %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
  }
}
// matmul of tile of A and tile of B
scf.for %arg5 = %c0 to %1 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    scf.for %arg7 = %c0 to %c256 step %c1 {
      %14 = memref.load %7[%arg5, %arg7] : memref<?x256xf32, #map2>
      %15 = memref.load %10[%arg7, %arg6] : memref<256x2048xf32, #map3>
      %16 = memref.load %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
      %17 = mulf %14, %15 : f32
      %18 = addf %16, %17 : f32
      memref.store %18, %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
    }
  }
}
// store result in tile of C
scf.for %arg5 = %c0 to %3 step %c1 {
  scf.for %arg6 = %c0 to %c2048 step %c1 {
    %14 = memref.load %13[%arg5, %arg6] : memref<?x2048xf32, #map3>
    memref.store %14, %4[%arg5, %arg6] : memref<?x2048xf32, #map0>
  }
}

Now I wouldn’t claim that we necessarily SHOULD fuse any of these loops. But it should be possible. My “desired” IR that I presented in my original post looks like it fuses the packing of B into the matmul in some way (albeit maybe after another level of tiling). However, I’m not too concerned about getting that exactly yet.

My question is, how would I go about doing any kind of fusion at this level? From what I’ve seen so far, there are a couple options:

  • Lower to affine and use -affine-loop-fusion
    • This won’t work because the pass only works on a sequence of perfectly nested loops.
    • The loops corresponding to the sequence of linalg ops are perfectly nested, however they are inside two outer loops which makes the entire loop structure not perfectly nested.
    • Even if this did work, I think there’s a limitation. Say I wanted to fuse the outer loops of the copy of A and copy of C, or the copy of C with matmul. They have the same outer loop(s), but I think that the fact that the same upper bound value is stored in different variables would prevent them from being able to fuse (referring to %1 and %3 being assigned to affine.min #map1(%arg4)).
  • Don’t lower linalg ops and try to use one of the linalg fusion passes
    • Most of these seem to only exist in test passes so maybe they haven’t been fully developed yet
    • test-linalg-greedy-fusion
      • naively applying this to the tiled IR above leaves it unchanged
    • test-linalg-fusion-transform-patterns
      • same result
    • linalg-fusion-for-tensor-ops
      • same result and the tests make it seem like this is only supposed to work for specific ops anyway. I even tried converting to generic ops first.
    • test-linalg-tile-and-fuse
      • Although I’m not particularly concerned about another level of tiling yet, given that we would eventually do another tiling anyway, this option would suffice
      • Unfortunately this doesn’t work. It returns a blank IR. I assume it’s because the linalg ops are inside scf.for loops and the pass doesn’t look for such cases?
    • There may be more options I haven’t seen, but it seems like none of these will work after a tile pass since the op calls are inside loops

This is already a long post, so I’ll leave it at that for now. But in summary, I’m just currently looking for a way to fuse ops after performing a tiling.

Thanks

Hi @srcarroll,

They are not always equivalent: tile by 0 will skip the loop, whereas tile by the problem size will create a new loop that may or may not be canonicalized later (it generally does get canonicalized in most simple cases).

Try -cse ?

Basically, in side-effecting buffer-land, transformations are considerably harder to perform. It is possible some of these transformations will slowly disappear. Instead, for the foreseeable future we prefer performing transformations at the tensor level, where SSA use-def chains are ubiquitous: this is where pad introduces linalg.tensor_padding operations. The main benefit this brings to the table is that hoisting such padding operations across multiple loops allows amortizing the padding/packing operations. If you use buffers and promotion, there is no solid mechanism in linalg to perform memory and alias analysis to hoist redundant copies (i.e. you end up doing O(n^3) packings when you could have done O(n^2)). In tensor-land + SSA use-def chains, all this just becomes available (almost) for free.
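One way to see the cost of missing hoisting is to count copied elements for the promoted IR posted earlier in the thread. There, the tile of B depends only on the outer K loop, yet its linalg.copy sits inside the inner M loop. The Python sketch below is only a counting model under that reading; the function name and constants are made up for illustration and take the loop trip counts (8 K-tiles, 33 M-tiles) and the 256x2048 tile shape from that IR.

```python
def copied_elements(m_tiles, k_tiles, tile_elems, hoisted):
    """Count elements copied when packing a tile that depends only on k."""
    copies = 0
    for _k in range(k_tiles):
        if hoisted:
            copies += 1            # one copy per K tile, outside the M loop
        for _i in range(m_tiles):
            if not hoisted:
                copies += 1        # same tile copied again for every M tile
    return copies * tile_elems


M_TILES = 33                  # ceil(2088 / 64)
K_TILES = 8                   # 2048 / 256
B_TILE_ELEMS = 256 * 2048     # shape of the promoted tile of B

unhoisted = copied_elements(M_TILES, K_TILES, B_TILE_ELEMS, hoisted=False)
hoisted = copied_elements(M_TILES, K_TILES, B_TILE_ELEMS, hoisted=True)

print(unhoisted // hoisted)   # redundancy factor of the unhoisted copy
```

In the hoisted case every element of B is copied exactly once, so the packing cost is proportional to the size of B rather than to the size of B times the number of M tiles.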

If you insist on lowering to loops over buffers before fusion then you are looking at trying to perform loop fusion. In that case you should indeed convert to affine and use affine loop fusion.

Linalg transformations are indeed higher-level: fusion is available as “tile and fuse” (i.e. rewrite an op as a loop nest that computes tile-by-tile and copy/move the producer/consumer op for that tile inside the loop nest). Note that in the limit, one can get similar behavior to loop-fusion by tiling by 0, 1 or “problem size” along some of the dimensions. Transformations at this level of abstraction compose more nicely in the tensor SSA-value domain.

There is indeed no first-class pass to perform fusion at the linalg level in core besides linalg-fusion-for-tensor-ops which implements a greedy “maximally fuse into a linalg region”-heuristic. All other passes are only test passes to check that the transformation patterns apply as expected. The main reason is that pass-level mechanisms and global heuristics to control this type of transformation are usually brittle, quickly become very complex to maintain and are very sensitive to handcrafted heuristics. OTOH, we aim for a higher-level of control. For now, MLIR core provides the transformations that compose and canonicalize to allow producing the optimized code variants we want. IREE has pass pipelines with heuristics that build on top of these facilities +(@MaheshRavishankar and @ThomasRaoux for some relevant links). We are looking into op-level control + transformations + search as a starting point and driving this from python (see this first test).

If your goal is to get a good grasp of where linalg is going, I’d recommend looking at the transformations that operate on tensors and that are driven by linalg-tensor-codegen-strategy (e.g. here or here) to see where the gradient is pointing to.

Hello, it’s been a while. I have questions/comments to a couple of your replies

I can’t actually get this to work in the context I outlined before because affine loop fusion doesn’t work for the contents of a loop. So I wouldn’t be able to do multi-level tiling while fusing linalg.fill/copy with linalg.matmul

Similar to my previous comment, it doesn’t seem like you can do multi-level tile-and-fuse because linalg’s tile-and-fuse doesn’t seem to work when linalg ops are inside other loops.

What do you suggest to achieve multi-level tile-and-fuse in general? Does it even make sense to do that?

Thanks

So I did a little more digging and looked at some test passes. I stumbled across -test-linalg-greedy-fusion. This almost does what I want. However, there seems to be a very bad bug. Either that, or since it’s a greedy strategy there is no guarantee that it will preserve the computation.

Suppose I start with (I was following the test examples for this format)

func @matmul(%arg0: memref<20x80xf32>,
             %arg1: memref<80x40xf32>,
             %arg2: memref<20x40xf32>) {
  %cst = constant 0.000000e+00 : f32
  %0 = memref.alloc() : memref<20x40xf32>
  linalg.fill(%0, %cst) : memref<20x40xf32>, f32
  linalg.matmul
    ins(%arg0, %arg1: memref<20x80xf32>,
                      memref<80x40xf32>)
    outs(%0: memref<20x40xf32>)
  return
}

I then tile matmul and apply greedy fusion by running

./bin/mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64" -test-linalg-greedy-fusion /path/to/mlir

This produces the following

#map0 = affine_map<(d0) -> (16, -d0 + 20)>
#map1 = affine_map<(d0) -> (-d0 + 20, 16)>
#map2 = affine_map<(d0) -> (32, -d0 + 40)>
#map3 = affine_map<(d0, d1)[s0] -> (d0 * 40 + s0 + d1)>
#map4 = affine_map<(d0) -> (-d0 + 40, 32)>
#map5 = affine_map<(d0) -> (64, -d0 + 80)>
#map6 = affine_map<(d0, d1)[s0] -> (d0 * 80 + s0 + d1)>
module  {
  func @matmul(%arg0: memref<20x80xf32>, %arg1: memref<80x40xf32>, %arg2: memref<20x40xf32>) {
    %cst = constant 0.000000e+00 : f32
    %c16 = constant 16 : index
    %c32 = constant 32 : index
    %c64 = constant 64 : index
    %c0 = constant 0 : index
    %c80 = constant 80 : index
    %c20 = constant 20 : index
    %c40 = constant 40 : index
    %0 = memref.alloc() : memref<20x40xf32>
    scf.for %arg3 = %c0 to %c20 step %c16 {
      %1 = affine.min #map0(%arg3)
      %2 = affine.min #map1(%arg3)
      scf.for %arg4 = %c0 to %c40 step %c32 {
        %3 = affine.min #map2(%arg4)
        %4 = memref.subview %0[%arg3, %arg4] [%1, %3] [1, 1] : memref<20x40xf32> to memref<?x?xf32, #map3>
        %5 = affine.min #map4(%arg4)
        %6 = memref.subview %0[%arg3, %arg4] [%2, %5] [1, 1] : memref<20x40xf32> to memref<?x?xf32, #map3>
        scf.for %arg5 = %c0 to %c80 step %c64 {
          %7 = affine.min #map5(%arg5)
          %8 = memref.subview %arg0[%arg3, %arg5] [%1, %7] [1, 1] : memref<20x80xf32> to memref<?x?xf32, #map6>
          %9 = memref.subview %arg1[%arg5, %arg4] [%7, %3] [1, 1] : memref<80x40xf32> to memref<?x?xf32, #map3>
          linalg.fill(%6, %cst) : memref<?x?xf32, #map3>, f32 
          linalg.matmul ins(%8, %9 : memref<?x?xf32, #map6>, memref<?x?xf32, #map3>) outs(%4 : memref<?x?xf32, #map3>)
        }
      }
    }
    return
  }
}

First of all, there is a redundancy here: %4 and %6 are the same subviews since
affine.min #map0(%arg3) = affine.min #map1(%arg3)
and
affine.min #map2(%arg4) = affine.min #map4(%arg4)
Second, the linalg.fill is placed inside the innermost loop. Thus on every iteration of the reduction loop, the corresponding values of the C matrix are reset to 0, which is wrong.

Is this expected? Do you have suggestions to get something similar but preserve computation?
I realize that linalg.fill and the tiled linalg.matmul shouldn’t be fusable with the way I’m using this pass anyway. I’d imagine having to tile the linalg.fill in a compatible way, i.e. with tile-sizes=16,32. However, if I try to tile both independently and then use greedy-fuse, this does not fuse the loops for seemingly the same reason I mentioned in my previous post.
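The effect of the fill placement can be reproduced in a few lines of Python. This is only a model of the loop structure, not of the pass itself: the function name and the fill_inside_reduction switch are made up for illustration, and the sizes and tile sizes (20x80 times 80x40, tiles 16,32,64) are the ones from the example above.

```python
def tiled_fill_matmul(A, B, M, N, K, tm, tn, tk, fill_inside_reduction):
    """Tiled C = A * B with a zero-fill placed at one of two loop levels."""
    # Start from garbage to mimic an uninitialized memref.alloc.
    C = [[-1.0] * N for _ in range(M)]

    def fill_tile(i0, j0):
        for i in range(i0, min(i0 + tm, M)):
            for j in range(j0, min(j0 + tn, N)):
                C[i][j] = 0.0

    for i0 in range(0, M, tm):
        for j0 in range(0, N, tn):
            if not fill_inside_reduction:
                fill_tile(i0, j0)          # once per output tile
            for k0 in range(0, K, tk):
                if fill_inside_reduction:
                    fill_tile(i0, j0)      # wipes earlier partial sums
                for i in range(i0, min(i0 + tm, M)):
                    for j in range(j0, min(j0 + tn, N)):
                        for k in range(k0, min(k0 + tk, K)):
                            C[i][j] += A[i][k] * B[k][j]
    return C


M, N, K = 20, 40, 80
A = [[1.0] * K for _ in range(M)]
B = [[1.0] * N for _ in range(K)]

good = tiled_fill_matmul(A, B, M, N, K, 16, 32, 64, fill_inside_reduction=False)
bad = tiled_fill_matmul(A, B, M, N, K, 16, 32, 64, fill_inside_reduction=True)

print(good[0][0], bad[0][0])  # prints 80.0 16.0
```

With all-ones inputs every entry should be K = 80. With the fill inside the reduction loop only the last K tile (16 elements wide) survives, which is the behavior I am describing.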

Thanks

The greedy fusion does this because it does no dependency analysis. The LinalgTileAndFusePattern (here) should handle this correctly by putting the fill within the second loop instead of the innermost.

Thank you for your reply. I’ve already tried using this with the test-linalg-tile-and-fuse command line pass. There are a couple problems

  1. Suppose you have a linalg.fill followed by a linalg.matmul. The former has 2 loops and the latter 3 loops. You can only pass a single list of tile sizes, so one of the matmul loops cannot get tiled. It should be theoretically possible to fuse the 2 spatial loops while still tiling the reduction loop in the matmul as long as the loop ordering is compatible. Correct me if I’m wrong.

  2. This only works on a single level tiling, but it will return an empty IR if the pass is applied on already tiled linalg ops. In fact, it’s not specific to tiling. The same will happen when the target linalg ops are inside loops

Is there a way around either of these issues?

Yup. You need to apply the tiling again. Typically you would stage it through markers. So you first do tile + fuse and have that transformation set a marker on the linalg.matmul once it is done. This would tile + fuse just the parallel loops (which are “compatible”). Then you use the marker as an anchor and just tile the reduction loop of the tiled linalg.matmul from the first step.
There are no command-line options or passes that do this automatically. Part of the reason is that there would need to be a way to specify this whole sequence of transformations through the command line in a generic fashion, and I don’t know of a clean mechanism to do that.

The level of tiling shouldn’t matter. As long as you have a sequence of operations within the same block this works fine (this is meant to work within scf ops which have a single block by construction). You can here use markers to control the sequence of tiling + fusion operations that you need.

Thanks for the speedy reply. What you suggest makes sense to me at a high level. However I’m very weak with mlir code so it might take me a while to apply your suggestion, but this points me in the right direction. Is there any relevant example code that you recommend to learn from? Thanks again

If you look at this test pass (here), it shows how the markers work. Each of the patterns has a marker that triggers the pattern application and once applied adds a different marker on the transformed op. You could use the same scheme to tile the matmul again.
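The marker scheme can be modeled outside of MLIR as a small state machine. The Python sketch below is a toy model only: the Op and Pattern classes, the marker strings, and the pattern labels are all made up for illustration and do not correspond to the actual MLIR C++ API. It shows the one idea described above: a pattern fires only on ops carrying its trigger marker and then stamps a different marker, so the order of transformations is fixed by the markers rather than by the order in which patterns are registered.

```python
class Op:
    """Toy stand-in for an operation carrying an optional marker."""

    def __init__(self, name, marker=None):
        self.name = name
        self.marker = marker
        self.applied = []          # labels of patterns applied so far


class Pattern:
    """Fires on (op_name, match_marker) and then sets new_marker."""

    def __init__(self, label, op_name, match_marker, new_marker):
        self.label = label
        self.op_name = op_name
        self.match_marker = match_marker
        self.new_marker = new_marker

    def matches(self, op):
        return op.name == self.op_name and op.marker == self.match_marker

    def apply(self, op):
        op.applied.append(self.label)
        op.marker = self.new_marker


def run_patterns(ops, patterns):
    """Apply patterns until no pattern matches any op."""
    changed = True
    while changed:
        changed = False
        for op in ops:
            for pattern in patterns:
                if pattern.matches(op):
                    pattern.apply(op)
                    changed = True
                    break


# Deliberately registered in reverse: the markers still enforce the order.
patterns = [
    Pattern("tile_reduction", "linalg.matmul", "parallel_tiled", "done"),
    Pattern("tile_and_fuse_parallel", "linalg.matmul", None, "parallel_tiled"),
]
matmul = Op("linalg.matmul")
fill = Op("linalg.fill")
run_patterns([fill, matmul], patterns)

print(matmul.applied, matmul.marker)
```

The final marker matches no pattern, which is what stops the rewrite loop from reapplying the same transformation forever.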

Any idea how to make this work with LinalgCodegenStrategy? I see that it has an option for passing a FilterFunction to the various functions like tile, tileIf, promote, promoteIf, etc., but I can’t find any examples of writing a FilterFunction. The markers in the test pass you showed me make sense when applied to patterns, but I can’t figure out how to use them with the codegen strategy. Thanks in advance