RFC: Next iteration of fusion of elementwise operations

This post is relevant to folks who use the transformations in the Linalg dialect for fusion of elementwise operations in their compilation pipeline.

The Linalg dialect has a bunch of patterns that deal with fusion of elementwise operations (here). The primary ones are:

  1. Fusion of linalg.generic operations that have producer → consumer relationships (here)
  2. Fusion of tensor.collapse_shape → linalg.generic (and conversely linalg.generic → tensor.expand_shape) by expanding the dimensionality of the linalg.generic operations.
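
The second pattern relies on the fact that an elementwise operation commutes with a reshape, so the reshape can be moved across the linalg.generic instead of blocking fusion. A minimal Python sketch of that property, using nested lists as stand-ins for tensors (the helpers `collapse`, `expand`, and `elementwise` are made up for illustration and are not MLIR APIs):

```python
def collapse(t):
    """Flatten a 2-D nested list row-major, loosely like tensor.collapse_shape."""
    return [x for row in t for x in row]


def expand(flat, inner):
    """Split a 1-D list into rows of length `inner`, loosely like tensor.expand_shape."""
    assert len(flat) % inner == 0, "inner size must divide the element count"
    return [flat[i:i + inner] for i in range(0, len(flat), inner)]


def elementwise(f, t):
    """Apply f to every scalar of a (possibly nested) list."""
    return [elementwise(f, x) if isinstance(x, list) else f(x) for x in t]


x = [[1, 2, 3], [4, 5, 6]]
f = lambda v: v * v + 1

# Running the elementwise op on the collapsed tensor gives the same values as
# running it on the expanded tensor and collapsing afterwards.
reshape_then_op = elementwise(f, collapse(x))
op_then_reshape = collapse(elementwise(f, x))
```

Because both orders produce the same values, a pattern is free to pick whichever side of the elementwise op is more convenient for the reshape.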

These patterns, when run together to a fixed point, fuse elementwise operations and reshapes while pushing the reshapes to the edges of the function, or to just before/after operations that the reshapes cannot be propagated through. With these patterns one can end up with operations like

%gemm = linalg.matmul ... -> tensor<?x?xf32>
%1 = tensor.expand_shape %gemm [[0, 1], [2]] : tensor<?x?xf32> into tensor<?x42x?xf32>
%2 = linalg.generic ... ins(... %1 ...) 

The tensor.expand_shape in between gets in the way of fusing the linalg.matmul with the linalg.generic later using tile + fuse. To address this issue, a bunch of other patterns exist that try to handle this in one-off ways:

  1. Fusion of reshape op with linalg.generic by linearization of the indexing maps (here). This was a very early attempt at fusing reshapes with elementwise operations. In general this leads to indexing maps that are not projected permutations, which hampers subsequent optimizations (like tile + fuse). The only reason these patterns are still around is for cases where unit-dimensions are folded (which happens a lot in ML models). These need to be purged.
  2. Push tensor.expand_shape operations past linalg.generic operations (here). This has some of the right functionality, but is constrained in its application (and also does not work for ops with indexing semantics).
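
To illustrate why linearization is a problem, here is a small Python sketch that checks whether an indexing map is a projected permutation. The encoding (each map result as a dict from loop dimension to coefficient) and the function name are assumptions made for this example, and the definition is simplified compared to what MLIR implements:

```python
def is_projected_permutation(results):
    """results: list of linear expressions, each a dict {loop_dim: coefficient}."""
    seen = set()
    for expr in results:
        # Each result must be exactly one loop dimension ...
        if len(expr) != 1:
            return False
        [(dim, coeff)] = expr.items()
        # ... used unscaled, and no dimension may appear twice.
        if coeff != 1 or dim in seen:
            return False
        seen.add(dim)
    return True


# (d0, d1, d2) -> (d0, d1, d2): plain access to a tensor<?x42x?xf32>.
expanded_map = [{0: 1}, {1: 1}, {2: 1}]

# (d0, d1, d2) -> (d0 * 42 + d1, d2): the same elements read through the
# collapsed tensor<?x?xf32> once the reshape is folded in by linearization.
linearized_map = [{0: 42, 1: 1}, {2: 1}]

# (d0, d1) -> (d0, d1): what the access looks like if the iteration space
# itself is collapsed to two loops instead.
collapsed_map = [{0: 1}, {1: 1}]
```

Only the linearized map fails the check, which is the property that gets in the way of later tiling.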

For the next iteration, I am planning to drop the two patterns above and replace them with a more general fusion of elementwise operations with reshapes by collapsing dimensions. This pattern tries to collapse dimensions of the iteration space of the linalg.generic to fuse it with the reshape, while keeping the indexing maps projected permutations.
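
A rough Python sketch of the idea, under the assumption that every indexing map is a projected permutation encoded as a list of loop-dimension indices. A group of consecutive loops is treated as collapsible only if every operand that uses those loops uses all of them, contiguously and in order. The names `can_collapse` and `collapse_maps` are hypothetical, and the real pattern has to handle more (iterator types, the reshape's reassociation, ops with indexing semantics):

```python
def can_collapse(indexing_maps, group):
    """group: consecutive loop dims, e.g. [0, 1], proposed to be folded into one."""
    group = list(group)
    for result_dims in indexing_maps:
        present = [d for d in result_dims if d in group]
        if not present:
            continue  # this operand does not depend on these loops at all
        if present != group:
            return False  # only part of the group is used, or it is reordered
        start = result_dims.index(group[0])
        if result_dims[start:start + len(group)] != group:
            return False  # another dimension is interleaved with the group
    return True


def collapse_maps(indexing_maps, group):
    """Rewrite the maps so the whole group becomes the single loop group[0]."""
    first, shrink = group[0], len(group) - 1
    new_maps = []
    for result_dims in indexing_maps:
        new_dims = []
        for d in result_dims:
            if d in group:
                if d == first:
                    new_dims.append(first)  # one result stands for the group
            else:
                # Loops after the group are renumbered to close the gap.
                new_dims.append(d - shrink if d > first else d)
        new_maps.append(new_dims)
    return new_maps


# Three operands over loops (d0, d1, d2): full access, a broadcast along d2,
# and the output.
maps = [[0, 1, 2], [2], [0, 1, 2]]
ok = can_collapse(maps, [0, 1])
collapsed = collapse_maps(maps, [0, 1])

# A transposed operand uses d0 and d1 in the opposite order, so the group
# cannot be collapsed without leaving projected permutations.
transposed_ok = can_collapse([[0, 1, 2], [1, 0, 2]], [0, 1])
```

In this encoding the collapsed maps are again lists of distinct loop indices, so they stay projected permutations by construction.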

The flow of fusion of elementwise operations is meant to be

  1. Run to a fixed point the fusion of linalg.generic → linalg.generic patterns along with the patterns that fuse tensor.collapse_shape → linalg.generic / linalg.generic → tensor.expand_shape.
  2. Run to a fixed point the fusion of tensor.expand_shape → linalg.generic. Note that initially I don't plan to add the corollary pattern of fusing linalg.generic → tensor.collapse_shape. The rationale here is that the tensor.expand_shape → linalg.generic fusion pushes reshapes “down”, which is better for subsequent tile + fuse. linalg.generic → tensor.expand_shape pushes reshapes “up”, which is detrimental to tile + fuse transformations later on.
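
The two-step flow can be mimicked on a toy model where a function is just a linear chain of op names. This is only a sketch: the pattern functions and the driver below are invented for illustration, step 1 models only the generic → generic fusion, and real IR is a graph rather than a list.

```python
def fuse_generics(ops):
    """generic, generic -> generic (elementwise producer/consumer fusion)."""
    for i in range(len(ops) - 1):
        if ops[i] == "generic" and ops[i + 1] == "generic":
            return ops[:i] + ["generic"] + ops[i + 2:]
    return None  # pattern did not match anywhere


def push_expand_down(ops):
    """expand_shape, generic -> generic, expand_shape (fusion by collapsing)."""
    for i in range(len(ops) - 1):
        if ops[i] == "expand_shape" and ops[i + 1] == "generic":
            return ops[:i] + ["generic", "expand_shape"] + ops[i + 2:]
    return None


def run_to_fixed_point(ops, patterns):
    """Keep applying the patterns until none of them matches any more."""
    changed = True
    while changed:
        changed = False
        for pattern in patterns:
            result = pattern(ops)
            if result is not None:
                ops, changed = result, True
    return ops


chain = ["matmul", "expand_shape", "generic", "generic"]
after_step1 = run_to_fixed_point(chain, [fuse_generics])
after_step2 = run_to_fixed_point(after_step1, [push_expand_down])
```

After the second step the reshape sits below the generic op, so in this toy chain the matmul and the generic op end up adjacent, which is the situation tile + fuse wants.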

I have an initial patch up for the fusion by collapsing dimensions (⚙ D119365 [mlir][Linalg] Add pattern for folding reshape by collapsing.). I am going to prototype the effects and continue to iterate on this in the context of IREE (where I have access to whole models and can collect statistics on how these approaches perform). If this works out, the patterns mentioned above that are superseded by the fusion by collapsing will be deprecated. I'll post my findings on this RFC as and when I have them.

FYI: @gysit @nicolasvasilache


Should %0 be %gemm in your example?

Thanks. Fixed.

Following up on this: I prototyped this approach within IREE on a few different models that are tracked in IREE. All the changes I made are in MLIR (or will be soon). The only thing in IREE is the control function used to determine when the pattern should apply. The control function I used can be seen as part of this draft PR in IREE, but the simple reasoning I used is to not fuse when doing so would introduce redundant computation in any of the patterns. That simple heuristic works well for elementwise fusion with fixed point iteration. There is a case to be made for fusing even at the cost of redundant computation, but that probably needs to be a deliberate choice and not part of a fixed point iteration. A summary of my results on a few benchmarks is below.

| Model | Num generic ops before fusion | Generic ops after fusion (current) | Generic ops after fusion (modified) |
|---|---|---|---|
| bert_encoderbase | 2059 | 557 | 485 |
| collatzbase | 12 | 5 | 5 |
| deeplabbase | 106 | 87 | 59 |
| edge_detectionbase | 4 | 1 | 1 |
| fragmentbase | 11 | 1 | 1 |
| fullyconnectedbase | 12 | 3 | 3 |
| mnistbase | 10 | 8 | 6 |
| mobilebertbase | 1765 | 724 | 652 |
| mobilenetv2base | 103 | 84 | 56 |
| mobilenetv3base | 226 | 100 | 85 |
| mobilessdbase | 157 | 122 | 88 |
| posenetbase | 59 | 45 | 32 |
| resnet50base | 711 | 58 | 55 |
| unidirectional_lstmbase | 43 | 13 | 13 |

The modified approach is never worse than the current one, and in several cases (for example deeplabbase, 87 to 59 generic ops) it is quite a bit better.
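
The heuristic behind the control function described above (do not fuse when that would introduce redundant computation) can be approximated as "only fuse a producer that has a single user". The Python sketch below is an assumption about the shape of such a check, not the control function from the IREE draft PR; the graph encoding and the names `count_users` and `should_fuse` are invented for illustration.

```python
from collections import Counter


def count_users(ops):
    """ops: dict mapping op name -> list of operand (producer) names."""
    users = Counter()
    for operands in ops.values():
        # Count each consumer once, even if it reads the same producer twice.
        users.update(set(operands))
    return users


def should_fuse(producer, users):
    # Inlining a producer into one of several users leaves the original op
    # alive for the other users, so its work would be done more than once.
    return users[producer] == 1


# a feeds only b; b feeds both c and d.
graph = {"a": [], "b": ["a"], "c": ["b"], "d": ["b"]}
users = count_users(graph)
fuse_a_into_b = should_fuse("a", users)
fuse_b_into_c = should_fuse("b", users)
```

A check like this is cheap and monotonic, which is why it sits comfortably inside a fixed point iteration; fusing despite multiple users would be a separate, deliberate decision.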

⚙ D123153 [mlir][Linalg] Allow collapsing subset of the reassociations when fusing by collapsing. and ⚙ D123236 [mlir][Linalg] Split `populateElementwiseOpsFusionPatterns`. are the main changes needed to enable the modified approach. After landing these, I'll send out a few more patches that will

  • Remove the patterns that are now defunct (the fusion by linearization patterns, and the push reshape op pattern)
  • Deprecate the -linalg-fuse-elementwise-ops pass. The efficacy of the patterns is heavily determined by the cost function (encoded in the control function) used, so such an opinionated pass in MLIR core is probably ill-advised. It was always meant for testing, but has accumulated patterns over time. I'll deprecate this pass and move the tests in MLIR that use it to test passes defined in test/lib/Dialect/Linalg.