Background
In the past, there have already existed several passes/transforms related to fusion in MLIR, e.g.:
- Pass: `ElementwiseOpFusion`
- Transform: `scf::tileConsumerAndFuseProducer`
However, these solutions also have some limitations. For example, the `ElementwiseOpFusion` pass assumes `GenericOp` and only supports elementwise ops. Besides, both of them execute the fusion flow from consumer to producer, in which case the consumer is responsible for generating the outer loops (a.k.a. the iteration domain) and forces the producer to adapt to them. As is well known, in most deep learning workloads the computation-intensive ops, like `matmul` or convolution (so-called contractions), take up most of the time. As a result, many developers prefer to implement them via hand-written templates with specific algorithms, particularly involving how to tile the iteration domain and dispatch efficient kernels to multiple threads on either GPU or CPU. In this use case, it seems unreasonable to tile a producer `matmul` within common outer loops generated by a consumer `relu`. On the contrary, it is better to fuse the consumer `relu` into the tiled producer `matmul`. Recently, another useful interface named `scf::tileAndFuseConsumerOfSlice`, supported by the community as the counterpart of the existing `scf::tileAndFuseProducerOfSlice`, makes this possible.

Meanwhile, up to now there has been no combined pass/transform that fuses both the producers and the consumers of one contraction op within an arbitrary nested loop structure.
Take the classic MLP example consisting of `pack`+`fill`+`matmul`+`broadcast`+`add`+`relu`:

```
input    weight
  |         |
  |        pack
   \       /      bias
    matmul        /
        \   broadcast
         \   /
          add
           |
          relu
```

where the `matmul` has already been deeply tiled into a complex nested loop structure before fusion, e.g.:
```
%input = ...
%weight = tensor.pack (...)
%dest = linalg.fill (...)
%mm = scf.forall() {
  %extract_slice_0 = tensor.extract_slice %input [...]
  %extract_slice_1 = tensor.extract_slice %weight [...]
  %extract_slice_2 = tensor.extract_slice %dest [...]
  %0 = scf.for() {
    %1 = scf.for() {
      %extract_slice_3 = tensor.extract_slice %extract_slice_0 [...]
      %extract_slice_4 = tensor.extract_slice %extract_slice_1 [...]
      %extract_slice_5 = tensor.extract_slice %extract_slice_2 [...]
      %tiled_mm = linalg.matmul ins(%extract_slice_3, %extract_slice_4) outs(%extract_slice_5)
      %insert_slice_0 = tensor.insert_slice %tiled_mm into ...
      yield %insert_slice_0
    }
    yield %1
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0
  }
}
%add = linalg.add ins(%mm, ...) outs(...)
```
Motivation
We intend to fuse all the other ops (`pack`+`fill`+`broadcast`+`add`+`relu`) into these deeply tiled loops in one pass/transform.
Terminology
To better explain the algorithm below, we first put forward some necessary terminology.

pre-op and post-op fusion
- pre-op fusion: fuse a producer into its tiled consumer.
- post-op fusion: fuse a consumer into its tiled producer.

diffusion
This concept is borrowed from the physical notion of diffusion. Unlike the traditional definition or the well-known Diffusion Model in the generative AI field, diffusion here is mainly used to describe the recursive process of pre-op/post-op fusing all ops surrounding one central op.
Algorithm
`preOpFusion(OpOperand &tiledOperand)`
- Collect the chain of `extractSliceOp`s and the `originalOperand` of `tiledOperand`. In the above example:
  - `extractSliceOp` chain (from inner to outer): `%extract_slice_3 = tensor.extract_slice %extract_slice_0 [...]` → `%extract_slice_0 = tensor.extract_slice %input [...]`
  - `originalOperand`: `%input`
- Check the boundary: the producer of `originalOperand` must be tilable and not a contraction. If so, continue; otherwise, break.
- Filter all candidates by Linalg semantics and select the best candidate slice op via a cost model based on HAL (Hardware Analysis Layer). Currently, the smallest one measured by the `size` metric of the `extractSliceOp` is selected by default.
- Call `tileAndFuseProducerOfSlice`.
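The chain-collection step above can be sketched with a toy model. All names here (`Op`, `collect_extract_slice_chain`) are hypothetical stand-ins for MLIR's `Value`/`OpOperand` machinery, not the actual API:

```python
# Toy model of the inner-to-outer slice-chain walk in preOpFusion.
# Hypothetical names; the real implementation operates on MLIR IR.

class Op:
    def __init__(self, name, kind, operands=()):
        self.name = name
        self.kind = kind              # e.g. "extract_slice", or any other op
        self.operands = list(operands)

def collect_extract_slice_chain(tiled_operand):
    """Walk from the tiled operand through nested tensor.extract_slice
    defining ops (inner to outer) until reaching the original operand."""
    chain = []
    current = tiled_operand
    while current is not None and current.kind == "extract_slice":
        chain.append(current)
        current = current.operands[0]  # source tensor of the slice
    return chain, current              # (slice chain, originalOperand)
```

On the MLP example this walk would yield the chain `%extract_slice_3` → `%extract_slice_0` and the original operand `%input`.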
`postOpFusion(OpResult tiledResult)`
- Collect the chain of `[parallel]insertSliceOp`s and the `originalResult` of `tiledResult`. In the above example:
  - `[parallel]insertSliceOp` chain (from inner to outer): `%insert_slice_0 = tensor.insert_slice %tiled_mm into ...` → `tensor.parallel_insert_slice %0`
  - `originalResult`: `%mm`
- Check the boundary: the consumer of `originalResult` must be tilable and not a contraction. If so, continue; otherwise, break.
- Filter all candidates by Linalg semantics and select the best candidate slice op via a cost model based on HAL (Hardware Analysis Layer). Currently, the smallest one measured by the `size` metric of the `[parallel]insertSliceOp` is selected by default.
- Call `tileAndFuseConsumerOfSlice`.
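The default "smallest slice" selection can be illustrated with a minimal sketch. Here the `size` metric is modeled, as an assumption, by the product of the slice's static sizes, and the helper name `select_smallest_slice` is hypothetical:

```python
# Sketch of the default candidate selection: among the slice ops that
# passed the Linalg-semantic filter, pick the one with the smallest
# size metric (modeled here as the product of static slice sizes).
from math import prod

def select_smallest_slice(candidates):
    """candidates: list of (slice_op_name, static_sizes) pairs that
    already passed the semantic filter."""
    return min(candidates, key=lambda c: prod(c[1]))
```

A future HAL-based cost model would simply replace the `key` function here with a hardware-aware callback.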
`diffusion(Operation *tiledOp)`
- Prepare a `deque<Operation*>` with initial value `tiledOp`, and a `set<Operation*>` to record which operations have already been tiled.
- Pop the front of the deque as the new central op.
- Execute `preOpFusion`:
  a. Fuse the producer(s) of each operand of the central op.
  b. On success, push the new tiled op to the back of the deque and insert the original op into the set.
- Execute `postOpFusion`:
  a. Fuse the consumer(s) of each result of the central op.
  b. On success, push the new tiled op to the back of the deque and insert the original op into the set.
- Repeat from step 2 until the deque turns out empty.
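The worklist loop above can be sketched over a toy op graph. All names are hypothetical, and the dict-based graph plus the `fusable` predicate merely stand in for real pre-op/post-op fusion, which would rewrite IR:

```python
# BFS-like diffusion sketch: producers are tried first (pre-op fusion),
# then consumers (post-op fusion), per central op popped from the deque.
from collections import deque

def diffusion(tiled_op, producers, consumers, fusable):
    """producers/consumers: dicts mapping an op name to neighbor op names;
    fusable: predicate modeling the boundary check (tilable, non-contraction)."""
    worklist = deque([tiled_op])
    fused = {tiled_op}            # ops already tiled/fused
    order = []                    # BFS-like visit order of central ops
    while worklist:
        central = worklist.popleft()
        order.append(central)
        # pre-op fusion over producers, then post-op fusion over consumers
        for neighbor in producers.get(central, []) + consumers.get(central, []):
            if neighbor not in fused and fusable(neighbor):
                fused.add(neighbor)
                worklist.append(neighbor)
    return order
```

Run on the MLP example starting from the tiled `matmul`, this visits `pack` and `fill` before descending through `add` to `broadcast` and `relu`, matching the BFS flavor described above.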
Q & A
Q: Why collect the chain of `*_slice` ops from inner to outer rather than applying the existing transforms repeatedly from outer to inner?
A: Both `tileAndFuseProducerOfSlice` and `tileAndFuseConsumerOfSlice` are invoked with a `*SliceOp`, which in most transform usage is manually assigned before compile time. However, it is possible (and certainly legal) for a user to write an efficient template with multi-level `*_slice` ops, where it is hard to manually assign which one to use in the next round of application, especially in the `tileAndFuseConsumerOfSlice` case. Furthermore, the collected chain is more friendly for a future cost model (maybe a callback function) to select a better candidate slice op at runtime.
Q: How to fuse a reduction op?
A: In general, reduction fusion is blocked by tiling on the reduction dimension. Benefiting from the chain of `*_slice` ops mentioned above, we can filter the valid candidates by Linalg semantics.
Q: Why not call the `diffusion` method recursively instead of using a `deque`?
A: There are two major reasons:
- When the number of operations becomes extremely large, recursion may cause a stack overflow.
- From the traversal point of view, the recursive way is DFS-style while the `deque` is more like BFS-style, which is safe with respect to dominance relationships and produces more reasonable resultant IR.
Q: What is the difference between `diffusion` and some other, potentially simpler combination of `tileAndFuseProducerOfSlice` and `tileAndFuseConsumerOfSlice`?
A: In `scf::tileConsumerAndFuseProducer`, there indeed exists a `deque` used to recursively fuse a producer and the producer of that producer. But it cannot fuse the consumer of a producer at the same time, e.g.:

```
      producer2
       |      \
  producer1   consumer2
       |
 tiledConsumer
```

where `consumer2` may not be fused in the end.

```
matmul
    \   broadcast
     \   /
      add
       |
      relu
```

Similarly, in this sub-pattern of the above MLP example, although `add`+`relu` can be fused via multiple applications of `tileAndFuseConsumerOfSlice`, the pre-op fusion of `broadcast` would be omitted.
Moreover, `diffusion` can even deal with the following complex topology:

```
contraction
     \    op1    op3
      \   /  \   /
       add    op2
        |      |
       relu   op4
          \   /
           op5
```

NOTE that this topology also explains why we prefer the BFS-like traversal flavor, which ensures that `op5` is visited after `relu` and `op4`.
Q: What if one op lies between two contraction ops? E.g.:

```
matmul
  ...
  Op
   |
matmul
```

Should this `Op` be post-op fused with the first `matmul` or pre-op fused into the second one?
A: It is decided by which `matmul` executes `diffusion` first. IMO, `diffusion` should be executed in topological order. In other words, post-op fusion is expected to have higher priority.
Open
It is optional to make this either a transform or a pass. I am not sure which one is better and look forward to your suggestions.
I am also glad to hear comments or any other questions from you. Thanks a lot!