This post is relevant to folks who use the transformation in Linalg dialect for fusion of elementwise operations in their compilation pipeline
Linalg dialect has a bunch of patterns that deal with fusion of elementwise operations (here). The primary of them are
- Fusion of
linalg.generic
operations that have producer → consumer relationships (here) - Fusion of
tensor.collapse_shape
→linalg.generic
(and converselylinalg.generic
→tensor.expand_shape
) by expanding the dimensionality of thelinalg.generic
operations.
These patterns together when executed in a fix point fuse elementwise operations and reshapes while pushing the reshape to the edges of the function , or before/after operations that cannot propagate the reshape operations. With these patterns one can end up with operations like
%gemm = linalg.matmul ... -> tensor<?x?xf32>
%1 = tensor.expand_shape %gemm [[0, 1], [2]] : tensor<?x?xf32> into tensor<?x42x?xf32>
%2 = linalg.generic ... ins(... %1 ...)
This affects fusing the linalg.matmul
with linalg.generic
later using tile + fuse. To address this issue a bunch of other patterns exist that try to handle this in one-off ways
- Fusion of reshape op with
linalg.generic
by linearization of the indexing maps (here). This was a very early attempt of doing fusion of reshape with elementwise operations. In general this leads to indexing maps that are not projected permutations that hamper subsequent optimizations (like tile + fuse). The only reason they are still around is for cases where unit-dimensions are folded (which happens a lot in ML models). These need to purged. - Push
tensor.expand_shape
operations pastlinalg.generic
operations (here. This has some of the right functionality, but is constrained in its application (and also does not work for ops with indexing semantics)
For the next iteration, I am planning to drop these two patterns above and replace this with a more general fusion of elementwise operations with reshape by collapsing dimensions. This pattern tries to collapse dimensions of the iteration space in the linalg.generic
to fuse with the reshape, while maintaining indexing maps to be projected permutations.
The flow of fusion of elementwise operations is meant to be
- Run to fixed point fusion of
linalg.generic
→linalg.generic
patterns along with the patterns that fusetensor.collapse_shape
→linalg.generic
/linalg.generic
→tensor.expand_shape
. - Run to fixed point fusion of
tensor.expand_shape
→linalg.generic
. Note that initially I dont plan to add the corollary pattern of fusinglinalg.generic
→tensor.collapse_shape
. The rationale here is that thetensor.expand_shape
→linalg.generic
pushes reshapes “down”, which is better for subsequent tile + fuse.linalg.generic
→tensor.expand_shape
pushes reshape “up” which is detrimental to tile + fuse transformations later on.
I have an initial patch of the fusion by collapsing dimensions up (⚙ D119365 [mlir][Linalg] Add pattern for folding reshape by collapsing.) . I am going to prototype the effects and continue to iterate on this in context of IREE (where I have access to whole models and can collect statistics on how these approaches perform). If these work out, the patterns mentioned above that are superseded by the fusion by collapsing, will be deprecated. Ill post my findings on this RFC as and when I have them.
FYI: @gysit @nicolasvasilache