TL;DR
This RFC proposes an addition to the tiling interface: a new generateOperandTileValue method that enables fusion with consumers.
Motivation
The tiling interface currently only supports fusing with producers, via the generateResultTileValue interface method; there is no counterpart such as generateOperandTileValue for consumers. This means we cannot fuse consumers through the tiling interface.
Consider the following graph of ops, where the result of A is consumed by both B and C:
  A
 / \
B   C
The current tile-and-fuse infrastructure cannot merge these three ops into the same loop: we can tile B (or C) and fuse A in as its producer, but there is no way to also pull the sibling consumer into that loop.
Previous Discussions
Proposal
First, we will add a new method to the tiling interface:
InterfaceMethod<
  /*desc=*/[{
    Method to generate the code that produces a tile of an operand.

    Generates the IR that computes the tiled operation given a tile of
    one of its operands. The `offsets` and `sizes` describe the tile of
    that operand.
    - `offsets` provides the offset of the tile within the original
      (untiled) operand.
    - `sizes` provides the size of the tile.
  }],
  /*retType=*/"FailureOr<TilingResult>",
  /*methodName=*/"generateOperandTileValue",
  /*args=*/(ins
    "OpBuilder &":$b,
    "unsigned":$operandNumber,
    "ArrayRef<OpFoldResult>":$offsets,
    "ArrayRef<OpFoldResult>":$sizes),
  /*methodBody=*/"",
  /*defaultImplementation=*/[{
    return failure();
  }]
>
Next, we will implement this method for ops that already support the tiling interface. For example, for linalg ops we can obtain the operand-to-iteration-domain mapping with the "getMatchingIndexingMap" method and then call "getTiledImplementation" to implement this method.
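To make this concrete, here is a rough sketch of how the Linalg implementation could look. It assumes the operand's indexing map is a projected permutation, and the free function generateOperandTileValueForLinalg is purely illustrative (the real hook would live in the Linalg external model of the tiling interface):

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

// Sketch: map the operand tile back to the iteration domain via the
// operand's indexing map, then reuse getTiledImplementation.
static FailureOr<TilingResult>
generateOperandTileValueForLinalg(OpBuilder &b, Operation *op,
                                  unsigned operandNumber,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  OpOperand &operand = linalgOp->getOpOperand(operandNumber);
  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operand);
  // Only handle the simple case where each operand dimension maps to a
  // distinct iteration-space dimension.
  if (!indexingMap.isProjectedPermutation())
    return failure();

  // Start from the full iteration domain and constrain the dimensions that
  // are covered by the operand tile.
  SmallVector<Range> iterationDomain =
      cast<TilingInterface>(op).getIterationDomain(b);
  SmallVector<OpFoldResult> iterOffsets, iterSizes;
  for (const Range &range : iterationDomain) {
    iterOffsets.push_back(range.offset);
    iterSizes.push_back(range.size);
  }
  for (const auto &en : llvm::enumerate(indexingMap.getResults())) {
    unsigned dim = cast<AffineDimExpr>(en.value()).getPosition();
    iterOffsets[dim] = offsets[en.index()];
    iterSizes[dim] = sizes[en.index()];
  }

  // Reuse the existing tiling entry point on the mapped iteration-space tile.
  return cast<TilingInterface>(op).getTiledImplementation(b, iterOffsets,
                                                          iterSizes);
}

Note that in this sketch, iteration-space dimensions not referenced by the operand's indexing map are simply left at their full range; a real implementation would have to decide how to handle them.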
What about fusion?
In the current implementation, we achieve loop fusion mostly through FuseIntoContainingOp, so we use it as the example here. FuseIntoContainingOp finds a tensor.extract_slice of the producer's result inside the loop to obtain the result tile info, and then calls the generateResultTileValue method to re-generate the producer inside the loop (a rough sketch of this anchoring logic follows the example below).
For example, consider the following input:
A = "foo.op0"(I)
scf.forall (...) {
  A_slice = tensor.extract_slice(A) // Anchor
  B = "foo.op1"(A_slice)
}
"foo.op2"(A)
After fusing foo.op0 into the scf.forall, we get the following IR:
A = "foo.op"(I)
scf.forall (...) {
I_slice = tensor.extract_slice(I)
A_local = "foo.op0"(I_slice)
B = "foo.op1"(A_local)
}
"foo.op2"(A)
When we support fusion with consumers, we can analogously find the tensor.parallel_insert_slice corresponding to the loop output to obtain the operand tile info, and then call the generateOperandTileValue method to generate the consumer inside the loop (a rough sketch of such a driver follows the example below).
For example:
%1 = scf.forall ... {
  %0 = ...
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 into %o[...][...][...] : ... // Anchor
  }
}
%2 = "foo.op0"(%1)
"foo.op1"(%1)
"foo.op2"(%2)
After fusing foo.op0 into the scf.forall, we get the following IR:
%1:2 = scf.forall ... {
  %0 = ...
  %2_local = "foo.op0"(%0)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 into %o[...][...][...] : ... // Anchor
    tensor.parallel_insert_slice %2_local into %o1[...][...][...] : ... // Newly inserted
  }
}
"foo.op1"(%1#0)
"foo.op2"(%1#1)
It is worth noting that we will not remove the existing results of the scf.forall, so as not to break any other uses. If a result ends up with no other uses, it should be removed later by DCE.
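Similarly, a rough sketch of how a consumer-fusion driver could use the proposed method. The helper fuseConsumerIntoForall is hypothetical; it only locates the anchor and generates the tiled consumer, while adding the new shared output, rewiring the generated op's operand to the in-loop tile, and replacing the consumer's uses (as shown in the IR above) are omitted:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

// Sketch: fuse `consumer` into `forallOp` by anchoring on the
// tensor.parallel_insert_slice that produces the forall result it uses.
static FailureOr<TilingResult>
fuseConsumerIntoForall(RewriterBase &rewriter, TilingInterface consumer,
                       scf::ForallOp forallOp) {
  for (OpOperand &operand : consumer->getOpOperands()) {
    auto loopResult = dyn_cast<OpResult>(operand.get());
    if (!loopResult || loopResult.getOwner() != forallOp.getOperation())
      continue;
    // The i-th forall result is tied to the i-th shared output; its block
    // argument follows the induction variables.
    BlockArgument sharedOut = forallOp.getBody()->getArgument(
        forallOp.getRank() + loopResult.getResultNumber());
    for (Operation &yieldingOp : forallOp.getTerminator().getYieldingOps()) {
      auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
      if (!insertOp || insertOp.getDest() != sharedOut)
        continue;
      // The insert is the anchor: its offsets/sizes describe the tile of the
      // consumer's operand produced by this iteration of the loop. A real
      // transform would additionally rewire the generated op's operand to the
      // in-loop tile (insertOp.getSource()) and insert its results into a new
      // shared output.
      rewriter.setInsertionPoint(forallOp.getTerminator());
      return consumer.generateOperandTileValue(
          rewriter, operand.getOperandNumber(), insertOp.getMixedOffsets(),
          insertOp.getMixedSizes());
    }
  }
  return failure();
}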
Other issues
- As discussed in "Any plan to fuse consumers?": @rengolin mentioned complex bufferization logic, and I didn't understand why. The in-place logic should be specified by the op's bufferization interface; what does this have to do with tiling itself? - TBD
The code for this RFC has not been implemented yet; we hope to gather suggestions before starting the implementation. Any suggestions and questions are welcome, and thanks in advance for your attention.