Background
In the past, there have been several passes/transforms related to fusion in MLIR, e.g.:
- Pass: `ElementwiseOpFusion`.
- Transform: `scf::tileConsumerAndFuseProducer`.
However, these solutions have some limitations. For example, the `ElementwiseOpFusion` pass assumes `GenericOp` and only supports elementwise ops. Besides, both of them run the fusion flow from consumer to producer, in which case the consumer is responsible for generating the outer loops (a.k.a. the iteration domain) and forces the producer to adapt to them. As is well known, in most deep-learning workloads the compute-intensive ops, such as matmul or convolution (so-called contractions), take up most of the time. As a result, many developers prefer to implement them via hand-written templates with specific algorithms, particularly covering how to tile the iteration domain and dispatch efficient kernels onto multiple threads for either GPU or CPU. In this use case it seems unreasonable to tile a producer matmul within common outer loops generated by a consumer relu; on the contrary, it is better to fuse the consumer relu into the tiled producer matmul. Recently, another useful interface named `scf::tileAndFuseConsumerOfSlice`, supported by the community as the counterpart of the existing `scf::tileAndFuseProducerOfSlice`, makes this possible.
Meanwhile, up to now there has been no combined pass/transform that fuses both the producers and the consumers of one contraction op within an arbitrarily nested loop structure.
Take the classic MLP example consisting of pack+fill+matmul+broadcast+add+relu:

```
input    weight
  |        |
  |       pack
   \      /       bias
    matmul         |
       \       broadcast
        \       /
          add
           |
          relu
```
where matmul has already been deeply tiled with a complex nested loop structure before fusion, e.g.:

```mlir
%input = ...
%weight = tensor.pack (...)
%dest = linalg.fill (...)
%mm = scf.forall () {
  %extract_slice_0 = tensor.extract_slice %input [...]
  %extract_slice_1 = tensor.extract_slice %weight [...]
  %extract_slice_2 = tensor.extract_slice %dest [...]
  %0 = scf.for () {
    %1 = scf.for () {
      %extract_slice_3 = tensor.extract_slice %extract_slice_0 [...]
      %extract_slice_4 = tensor.extract_slice %extract_slice_1 [...]
      %extract_slice_5 = tensor.extract_slice %extract_slice_2 [...]
      %tiled_mm = linalg.matmul ins(%extract_slice_3, %extract_slice_4) outs(%extract_slice_5)
      %insert_slice_0 = tensor.insert_slice %tiled_mm into ...
      yield %insert_slice_0
    }
    yield %1
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 ...
  }
}
%add = linalg.add ins(%mm, ...) outs(...)
```
Motivation
We intend to fuse all the other ops (pack+fill+broadcast+add+relu) into this deeply tiled loop nest in one pass/transform.
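For intuition, here is a hand-written sketch (not actual compiler output; the exact fused op spellings, e.g. for the tiled relu, are illustrative only) of roughly what the IR could look like once all surrounding ops are fused into the tiled loop nest:

```mlir
%mm = scf.forall () {
  ...
  %0 = scf.for () {
    %1 = scf.for () {
      // pre-op fused producers
      %packed = tensor.pack (...)          // tile of %weight
      %filled = linalg.fill (...)          // tile of %dest
      %tiled_mm = linalg.matmul ins(..., %packed) outs(%filled)
      // post-op fused consumers
      %bcast = linalg.broadcast (...)      // tile of bias
      %tiled_add = linalg.add ins(%tiled_mm, %bcast) outs(...)
      %tiled_relu = ...                    // tiled relu on %tiled_add
      %insert_slice_0 = tensor.insert_slice %tiled_relu into ...
      yield %insert_slice_0
    }
    yield %1
  }
  ...
}
```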
Terminology
To better explain the algorithm below, we first put forward some necessary terminology.
pre-op and post-op fusion
- pre-op fusion: fuse a producer into the tiled consumer.
- post-op fusion: fuse a consumer into the tiled producer.
diffusion
The concept is borrowed from the general notion of diffusion. However, unlike the traditional definition, or the well-known diffusion models in the generative-AI field, diffusion here describes the recursive process of pre-op/post-op fusing all ops surrounding one central op.
Algorithm
`preOpFusion(OpOperand &tiledOperand)`

- Collect the chain of `extractSliceOp`s and the `originalOperand` of `tiledOperand`. In the above example:
  - `extractSliceOp` chain (from inner to outer): `%extract_slice_3 = tensor.extract_slice %extract_slice_0 [...]` → `%extract_slice_0 = tensor.extract_slice %input [...]`
  - `originalOperand`: `%input`
- Check the boundary: if the producer of `originalOperand` is tilable and not a contraction, continue; otherwise, break.
- Filter all candidates by Linalg semantics and select the best candidate slice op via a cost model based on HAL (Hardware Analysis Layer). Currently, the smallest candidate measured by the `size` metric of the `extractSliceOp` is selected by default.
- Call `tileAndFuseProducerOfSlice`.
`postOpFusion(OpResult tiledResult)`

- Collect the chain of `[parallel]insertSliceOp`s and the `originalResult` of `tiledResult`. In the above example:
  - `[parallel]insertSliceOp` chain (from inner to outer): `%insert_slice_0 = tensor.insert_slice %tiled_mm into ...` → `tensor.parallel_insert_slice %0`
  - `originalResult`: `%mm`
- Check the boundary: if the consumer of `originalResult` is tilable and not a contraction, continue; otherwise, break.
- Filter all candidates by Linalg semantics and select the best candidate slice op via a cost model based on HAL (Hardware Analysis Layer). Currently, the smallest candidate measured by the `size` metric of the `[parallel]insertSliceOp` is selected by default.
- Call `tileAndFuseConsumerOfSlice`.
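The candidate-selection step shared by `preOpFusion` and `postOpFusion` can be modeled in plain Python. This is only a toy stand-in: `Candidate`, the `valid` flag (standing for the Linalg-semantics filter), and the product-of-sizes metric are illustrative assumptions, not actual MLIR APIs:

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    """Toy stand-in for a tensor.extract_slice / insert_slice op."""
    name: str
    sizes: tuple   # static slice sizes
    valid: bool    # whether it passed the Linalg-semantics filter

def slice_size(c: Candidate) -> int:
    """Default cost-model metric: product of the slice sizes."""
    n = 1
    for s in c.sizes:
        n *= s
    return n

def select_candidate(chain):
    """Filter by semantics, then pick the smallest slice by size."""
    valid = [c for c in chain if c.valid]
    return min(valid, key=slice_size) if valid else None

# Chain collected from inner to outer, as in the matmul example.
chain = [
    Candidate("%extract_slice_3", (32, 32), True),    # innermost tile
    Candidate("%extract_slice_0", (256, 512), True),  # outer slice
]
best = select_candidate(chain)
print(best.name)  # prints %extract_slice_3 (the smallest slice wins)
```

A real HAL-based cost model would replace `slice_size` with a hardware-aware callback, which is exactly why the whole chain is collected rather than a single pre-assigned slice op.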
`diffusion(Operation *tiledOp)`

- Prepare a `deque<Operation*>` with initial value `tiledOp`, and a `set<Operation*>` to record which operations have already been tiled.
- Pop the front of the deque as the new central op.
- Execute `preOpFusion`:
  a. fuse the producer(s) of the operands of the central op;
  b. if successful, push the new tiled op onto the back of the deque and insert the original op into the set.
- Execute `postOpFusion`:
  a. fuse the consumer(s) of the results of the central op;
  b. if successful, push the new tiled op onto the back of the deque and insert the original op into the set.
- Repeat from step 2 until the deque turns out empty.
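The steps above can be sketched as a small Python model of the traversal (a toy: the `graph` dict of producer/consumer names stands in for real `Operation` handles, and every fusion attempt is assumed to succeed):

```python
from collections import deque

def diffusion(graph, tiled_op):
    """BFS-like diffusion; graph maps op -> (producers, consumers)."""
    worklist = deque([tiled_op])   # deque with initial value tiledOp
    fused = {tiled_op}             # set of already-tiled ops
    order = []                     # fusion order, for inspection
    while worklist:
        op = worklist.popleft()    # pop front as the new central op
        producers, consumers = graph[op]
        for p in producers:        # preOpFusion
            if p not in fused:
                fused.add(p)
                order.append(p)
                worklist.append(p)
        for c in consumers:        # postOpFusion
            if c not in fused:
                fused.add(c)
                order.append(c)
                worklist.append(c)
    return order

# The MLP example: pack/fill feed matmul, broadcast feeds add.
graph = {
    "pack":      ((), ("matmul",)),
    "fill":      ((), ("matmul",)),
    "broadcast": ((), ("add",)),
    "matmul":    (("pack", "fill"), ("add",)),
    "add":       (("matmul", "broadcast"), ("relu",)),
    "relu":      (("add",), ()),
}
order = diffusion(graph, "matmul")
print(order)  # ['pack', 'fill', 'add', 'broadcast', 'relu']
```

All five surrounding ops end up fused around the central matmul, matching the motivation above.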
Q & A
Q: Why collect the chain of `*_slice` pairs from inner to outer rather than applying the existing transforms multiple times from outer to inner?
A: Both `tileAndFuseProducerOfSlice` and `tileAndFuseConsumerOfSlice` are invoked with a `*SliceOp`, which in most transform usage is manually assigned before compile time. However, it is possible (and certainly legal) for a user to write an efficient template with multi-level `*_slice`s, where it is hard to manually assign which slice op to use in the next round of application, especially in the `tileAndFuseConsumerOfSlice` case. Furthermore, collecting the whole chain is friendlier for a future cost model (maybe a callback function) to select a better candidate slice op at run time.
Q: How to fuse a reduction op?
A: In general, reduction fusion is blocked by tiling on the reduction dimension. Benefiting from the chain of `*_slice`s mentioned above, we can filter for valid candidates by Linalg semantics.
Q: Why not call the diffusion method recursively instead of using a deque?
A: There are two major reasons:
- When the number of operations becomes extremely large, recursion may cause stack overflow.
- From the traversal point of view, the recursive way is DFS-style while the deque appears more like BFS-style, which is safe with respect to dominance relationships and produces more reasonable resultant IR.
Q: What is the difference between diffusion and an otherwise simple combination of `tileAndFuseProducerOfSlice` and `tileAndFuseConsumerOfSlice`?
A: In `scf::tileConsumerAndFuseProducer`, there indeed exists a deque used to recursively fuse a producer and the producer of that producer. But it cannot fuse a consumer of a producer at the same time, e.g.:

```
     producer2
     |       \
producer1   consumer2
     |
tiledConsumer
```

where consumer2 may not be fused in the end.
```
matmul
    \     broadcast
     \     /
       add
        |
       relu
```

Similarly, in this sub-pattern of the above MLP example, although add+relu can be fused via multiple applications of `tileAndFuseConsumerOfSlice`, the pre-op fusion of broadcast may be omitted.
Moreover, diffusion can even deal with the following complex topology:

```
contraction
      \    op1    op3
       \   /  \   /
        add    op2
         |      |
        relu   op4
           \   /
            op5
```

NOTE: this topology also explains why we prefer a BFS-like traversal, which ensures that op5 is visited only after relu and op4.
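A minimal Python sketch of why the BFS-style worklist gets op5 right. This is a toy model under an explicit assumption: a consumer is fused only once each of its producers is already fused or is a pure input, a crude stand-in for the real dominance check (op names mirror the diagram; none of this is MLIR API):

```python
from collections import deque

def diffusion(graph, tiled_op):
    """BFS-like diffusion; graph maps op -> (producers, consumers)."""
    worklist, fused, order = deque([tiled_op]), {tiled_op}, []

    def fusable(consumer):
        # crude dominance stand-in: every producer already fused or a pure input
        return all(p in fused or graph[p][0] == () for p in graph[consumer][0])

    while worklist:
        op = worklist.popleft()
        producers, consumers = graph[op]
        for p in producers:                      # preOpFusion
            if p not in fused:
                fused.add(p); order.append(p); worklist.append(p)
        for c in consumers:                      # postOpFusion
            if c not in fused and fusable(c):
                fused.add(c); order.append(c); worklist.append(c)
    return order

graph = {
    "contraction": ((), ("add",)),
    "op1":  ((), ("add", "op2")),
    "op3":  ((), ("op2",)),
    "add":  (("contraction", "op1"), ("relu",)),
    "op2":  (("op1", "op3"), ("op4",)),
    "relu": (("add",), ("op5",)),
    "op4":  (("op2",), ("op5",)),
    "op5":  (("relu", "op4"), ()),
}
order = diffusion(graph, "contraction")
print(order)  # op5 comes last, after both relu and op4
```

When relu is processed, op5 is skipped because op4 is not yet fused; it is naturally retried and fused when op4 is processed later, so op5 ends up last.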
Q: What if one op lies between two contraction ops, e.g.:

```
matmul
  ...
  Op
   |
matmul
```

Is this Op post-op fused with the first matmul or pre-op fused into the second one?
A: It is decided by which matmul executes diffusion first. IMO, diffusion should be executed in topological order. In other words, post-op fusion takes higher priority.
Open
It is possible to make this either a transform or a pass, and I am not sure which one is better. Looking forward to your suggestions.
I am also glad to hear comments or any other questions from you. Thanks a lot!