TL;DR: Is there a way to make the rewriter match a pattern based on MLIR code, instead of using PDL or traversing the IR in a custom pass with an anchor op? Can it be done without recompiling the compiler for each new pattern?
Hi,
I am working on generating accelerators and driver code for different (linalg) algorithms. As such, linalg.generic is great, since it can capture the class of algorithms I am interested in.
However, the accelerators are not derived from MLIR, and the compiler only knows which operations they can cover if I create a specific match/rewrite for each accelerator_type × dataflow combination.
With the goal of enabling the user to write a simple config file for driving their “complex” accelerator, I started playing with the idea of exposing enough accelerator parameters (size, instructions, operation, etc.) as pass options or as an input config file that controls the codegen. All of this works fine as long as I know the operation ahead of time (anchor-op style).
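For context, my current anchor-op flow is driven by something like the following invocation (the pass name and its options are made up here, purely to illustrate the shape of the interface):

# hypothetical custom pass + options, not an existing upstream pass
mlir-opt input.mlir -my-accel-codegen="anchor-op=linalg.generic accel-size=16 opcode=mulf"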
What I would like is to ask the user to write a single linalg.generic + trait (that includes the usual generic information plus the accelerator parameters) in a config.mlir file to control the transformations.
Questions
Is it possible to pattern match based on this external .mlir config file that has MLIR syntax? Are there any examples in tree?
Considering that the file the user writes round-trips fine with mlir-opt, would it make sense for me to compile it separately and try to extract the operation information from its symbol table and use it on the target code?
Are there any PDL examples used with linalg.generic patterns?
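For reference, here is a minimal sketch of what I imagine the PDL side could look like. The constraint and rewriter names (bodyIsMulf, buildAccelDriverCalls) are placeholders I made up, and as far as I can tell PDL cannot inspect the region body directly, so that check would still have to live in C++:

pdl.pattern @generic_mulf_to_accel : benefit(1) {
  %maps = pdl.attribute
  %iters = pdl.attribute
  %A = pdl.operand
  %B = pdl.operand
  %C = pdl.operand
  // Match a linalg.generic, capturing its indexing maps and iterator types.
  %root = pdl.operation "linalg.generic"(%A, %B, %C : !pdl.value, !pdl.value, !pdl.value)
      {"indexing_maps" = %maps, "iterator_types" = %iters}
  // The region body (arith.mulf + linalg.yield) would be checked by a
  // registered C++ native constraint (hypothetical name).
  pdl.apply_native_constraint "bodyIsMulf"(%root : !pdl.operation)
  // Emit the accel.send/accel.recv driver code from a registered C++
  // rewriter (hypothetical name).
  pdl.rewrite %root with "buildAccelDriverCalls"
}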
I appreciate any suggestions.
Thank you!
My example:
// Original MLIR
...
linalg.generic {indexing_maps = [#map5, #map6, #map5], iterator_types = ["parallel"]} ins(%104, %2 : memref<16xf32>, memref<f32>) outs(%166 : memref<16xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
  %513 = arith.addf %arg2, %arg3 : f32
  linalg.yield %513 : f32
}
%167 = memref.alloc() : memref<16xf32>
linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel"]} ins(%166 : memref<16xf32>) outs(%167 : memref<16xf32>) {
^bb0(%arg2: f32, %arg3: f32):  // no predecessors
  %513 = math.rsqrt %arg2 : f32
  linalg.yield %513 : f32
}
%168 = memref.expand_shape %167 [[0, 1]] : memref<16xf32> into memref<1x16xf32>
%169 = memref.alloc() : memref<1x80x80x16xf32>
linalg.generic {indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%165, %168 : memref<1x80x80x16xf32>, memref<1x16xf32>) outs(%169 : memref<1x80x80x16xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
  %513 = arith.mulf %arg2, %arg3 : f32
  linalg.yield %513 : f32
}
...
// User config file - defines the pattern below, which is matched against the code above
#map2 = ...
#map4 = ...
#trait = {
  indexing_maps = [#map2, #map4, #map2],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"],
  opcodes = {(..,..),..},
  flow_pattern = ...,
  etc = ...,
}
linalg.generic #trait ins(%A, %B : memref<?x?x?x?xf32>, memref<?x?xf32>)
                      outs(%C : memref<?x?x?x?xf32>) {
^bb0(%0: f32, %1: f32, %2: f32):
  %3 = arith.mulf %0, %1 : f32
  linalg.yield %3 : f32
}
// Transformed code
...
linalg.generic {indexing_maps = [#map5, #map6, #map5], iterator_types = ["parallel"]} ins(%104, %2 : memref<16xf32>, memref<f32>) outs(%166 : memref<16xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
  %513 = arith.addf %arg2, %arg3 : f32
  linalg.yield %513 : f32
}
%167 = memref.alloc() : memref<16xf32>
linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel"]} ins(%166 : memref<16xf32>) outs(%167 : memref<16xf32>) {
^bb0(%arg2: f32, %arg3: f32):  // no predecessors
  %513 = math.rsqrt %arg2 : f32
  linalg.yield %513 : f32
}
%168 = memref.expand_shape %167 [[0, 1]] : memref<16xf32> into memref<1x16xf32>
%169 = memref.alloc() : memref<1x80x80x16xf32>
// Replaced the last linalg.generic that matched the pattern with accelerator driver code.
accel.send %165 : memref<1x80x80x16xf32>
accel.send %168 : memref<1x16xf32>
accel.recv %169 : memref<1x80x80x16xf32>
...