This is possible in a limited case. A version of this is implemented here.
As @gysit mentioned, there are two kinds of fusion in Linalg: one is elementwise fusion and the other is tile + fuse. The file I linked to above does the elementwise fusion. Broadly, it does three different things:
- Fuse `linalg.generic` with other `linalg.generic` to create a new `linalg.generic` operation. This is akin to fusion where the fusion results in perfectly nested loops.
- Fuse `tensor.collapse_shape` → `linalg.generic` and `linalg.generic` → `tensor.expand_shape`. This is done by expanding the dimensionality of the `linalg.generic` op in both cases. This is called “FuseByExpansion”.
- Fuse `tensor.expand_shape` → `linalg.generic` and `linalg.generic` → `tensor.collapse_shape`. This is done by collapsing the dimensionality of the `linalg.generic` op in both cases. This is called “FuseByCollapsing”. Here you need to be careful that the fusion does not result in indexing maps that are not “projected permutations”, because that affects subsequent analysis and optimizations.
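To make the “projected permutation” condition a bit more concrete, here is a minimal Python sketch of the check. The encoding is purely hypothetical (an `int` stands for a plain dimension `d_i`, a string stands for any compound expression); it is not the MLIR API, just a toy model of the property that each result of the indexing map must be a distinct plain dimension.

```python
from typing import Sequence, Union

# Hypothetical encoding of one indexing-map result (an assumption for this
# sketch): an int i means the plain dimension d_i, a str means any compound
# expression such as "d0 floordiv 4".
Result = Union[int, str]


def is_projected_permutation(num_dims: int, results: Sequence[Result]) -> bool:
    """Return True if every result is a distinct plain dimension in d0..d{num_dims-1}."""
    if len(results) > num_dims:
        return False
    seen = set()
    for r in results:
        # Compound expressions (and stray bools) are not plain dimensions.
        if not isinstance(r, int) or isinstance(r, bool):
            return False
        # Out-of-range or repeated dimensions break the property.
        if not 0 <= r < num_dims or r in seen:
            return False
        seen.add(r)
    return True


# (d0, d1, d2) -> (d2, d0): a transposing projection, fine.
print(is_projected_permutation(3, [2, 0]))
# (d0, d1) -> (d0, d0): the same dimension is used twice.
print(is_projected_permutation(2, [0, 0]))
# (d0, d1) -> (d0 floordiv 4, d1): contains a compound expression.
print(is_projected_permutation(2, ["d0 floordiv 4", 1]))
```

The last case is the kind of map that collapsing can produce if it is applied carelessly, which is why it needs the extra care mentioned above.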
All three of these are implemented in that file. There are some other patterns in that file as well that are meant to be deprecated. See this post for more details and background. I haven't pushed on this as much as I would have liked. I implemented the patterns needed, but haven't deprecated the things that should be, because I am waiting to decouple their uses from IREE.
This is precisely what `FuseByExpansion`, as implemented, does. It is already used in IREE and works quite well; it improves the ability to fuse. After some recent study in IREE, we found that we needed “FuseByCollapsing” as well, to propagate the reshapes further to the edges when they get “blocked” by certain operations (the Discourse post I linked above has more details).
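As a rough illustration of the idea behind fusing by expansion, here is a toy Python model on nested lists. All function names are made up for this sketch and none of it is the actual MLIR pattern. It only shows why a collapse followed by an elementwise op can be rewritten as the same elementwise op on the expanded shape followed by the collapse, which pushes the reshape toward the edge.

```python
from typing import Callable, List


def collapse(x: List[List[int]]) -> List[int]:
    """Toy stand-in for a 2-D to 1-D collapse: flatten in row-major order."""
    return [v for row in x for v in row]


def elementwise_1d(f: Callable[[int], int], x: List[int]) -> List[int]:
    """Toy stand-in for an elementwise op over a 1-D iteration space."""
    return [f(v) for v in x]


def elementwise_2d(f: Callable[[int], int], x: List[List[int]]) -> List[List[int]]:
    """The same elementwise op with its iteration space expanded to 2-D."""
    return [[f(v) for v in row] for row in x]


def before(x: List[List[int]], f: Callable[[int], int]) -> List[int]:
    # collapse -> elementwise: the reshape sits between producer and consumer.
    return elementwise_1d(f, collapse(x))


def after(x: List[List[int]], f: Callable[[int], int]) -> List[int]:
    # elementwise (expanded) -> collapse: the reshape has moved past the op.
    return collapse(elementwise_2d(f, x))


x = [[1, 2, 3], [4, 5, 6]]
print(before(x, lambda v: v * 10) == after(x, lambda v: v * 10))
```

Once the reshape is out of the way, the expanded elementwise op sits directly next to whatever produced `x`, so the two become candidates for ordinary elementwise fusion.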
Gather is, for now, implemented as a `linalg.generic` when it is lowered from, say, the MHLO dialect in TF, or the TOSA dialect (or from Torch-MLIR). Since it's a `linalg.generic` operation, it fuses with other operations the same way you would expect. We have seen cases in some BERT models where multiple gathers get fused because of the elementwise operation fusion.
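Here is a small Python sketch of what that looks like at the loop level. The table names and shapes are invented for illustration (loosely modeled on an embedding lookup) and are not taken from any particular model. The unfused version materializes each gather result before adding them; the fused version does both lookups and the add in a single loop nest with no intermediates.

```python
from typing import List

Matrix = List[List[int]]


def gather_rows(table: Matrix, ids: List[int]) -> Matrix:
    """Gather as an elementwise op: out[i][j] = table[ids[i]][j]."""
    return [[table[ids[i]][j] for j in range(len(table[0]))]
            for i in range(len(ids))]


def add(a: Matrix, b: Matrix) -> Matrix:
    """Plain elementwise add over the same 2-D iteration space."""
    return [[a[i][j] + b[i][j] for j in range(len(a[0]))]
            for i in range(len(a))]


def unfused(word: Matrix, pos: Matrix, ids: List[int], pos_ids: List[int]) -> Matrix:
    # Three separate ops, two intermediate tensors.
    return add(gather_rows(word, ids), gather_rows(pos, pos_ids))


def fused(word: Matrix, pos: Matrix, ids: List[int], pos_ids: List[int]) -> Matrix:
    # One loop nest: both gathers are folded into the body of the add.
    return [[word[ids[i]][j] + pos[pos_ids[i]][j] for j in range(len(word[0]))]
            for i in range(len(ids))]


# Hypothetical tiny tables, just to exercise the two versions.
word = [[1, 2], [3, 4], [5, 6]]
pos = [[10, 20], [30, 40]]
print(unfused(word, pos, [2, 0], [0, 1]) == fused(word, pos, [2, 0], [0, 1]))
```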
I just provided a high-level overview in my response. I can go into a lot of detail, but that would take a really long time. I can give you more targeted info if you need it, but from your questions I think a lot of the functionality you are looking for is already implemented in the `ElementwiseFusion` approach to fusion in Linalg. If there is something missing, I would love to know more about your use case here. I'm only asking because the kinds of questions you asked and the approaches you mention match well with what we’ve tried in Linalg and used in IREE, so it seems like there might be some synergy here.