Extending `tileConsumerAndFuseProducer` to handle more patterns

To get transform.structured.fuse to work with tosa.concat, I took a stab at implementing a TilingInterface for it whose generateResultTileValue would generate something like this for tosa.concat(%in1, %in2) (roughly speaking):

%cmp1 = arith.cmpi slt, %offset, %in1_size
%out = scf.if %cmp1 {
    %slice1 = tensor.extract_slice %in1[%offsets][%sizes]
    scf.yield %slice1
} else {
    %offset2 = arith.subi %offset, %in1_size
    %slice2 = tensor.extract_slice %in2[%offset2][%sizes]
    scf.yield %slice2
}
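
To make the branch selection in the snippet above concrete, here is a minimal sketch in plain C++ that models it along the concat axis only. The names `ConcatSlice` and `selectConcatSlice` are hypothetical and not part of MLIR. The sketch also assumes that a tile never straddles the boundary between the two inputs, which the single `scf.if` above implicitly relies on.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the slice that the tiled concat reads from.
struct ConcatSlice {
  int inputIndex;      // 0 selects %in1, 1 selects %in2
  std::int64_t offset; // offset local to the selected input
  std::int64_t size;   // tile size along the concat axis
};

// Mirrors the scf.if: a tile that starts before the end of %in1 reads from
// %in1; any other tile reads from %in2 at an offset rebased by in1Size.
// Assumption: the tile does not cross the boundary between the two inputs.
inline ConcatSlice selectConcatSlice(std::int64_t offset, std::int64_t size,
                                     std::int64_t in1Size) {
  assert(offset >= 0 && size > 0 && in1Size >= 0);
  if (offset < in1Size) {
    assert(offset + size <= in1Size && "tile straddles the concat boundary");
    return {0, offset, size};
  }
  return {1, offset - in1Size, size};
}
```

For example, with `in1Size = 4`, a tile at offset 6 of size 2 maps to input 1 at local offset 2, which corresponds to the `arith.subi` in the else branch.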

I tried returning the two scf.yield operations in the tiledOps field of TilingResult and the final scf.if’s result in the final tiledValues. However, this prevents the producers of %in1 and %in2 from getting further fused into the resulting tiled version of the concat op. This seems to be because the assumption is that the last value in tiledOps contains an operation whose operands directly refer to some tensor.extract_slice ops. For the tosa.concat example above, that becomes difficult.
To get this to work, I had to make some modifications to the tileConsumerAndFuseProducer code here. However, I am wondering if there’s some other trick I am missing in how to implement the TilingInterface which will allow the fusion algorithm to work as is with the above structure.

To get the above tiling to work with transform.structured.fuse, I had to make the following modifications to tileConsumerAndFuseProducers:

--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1045,7 +1045,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   };

   std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tiledAndFusedOps.back(), candidates);
+  for (auto *op : tiledAndFusedOps)
+    addCandidateSlices(op, candidates);
+
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
     // Traverse the slices in BFS fashion.
@@ -1087,7 +1089,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
       tiledAndFusedOps.insert(tiledAndFusedOp);
-      addCandidateSlices(tiledAndFusedOp, candidates);
+      for (auto *op : fusedResult->tiledOps)
+        addCandidateSlices(op, candidates);
     }
   }

While this does not seem to cause any failures in the existing lit test-suite, @qed mentioned that it might be a potential foot-gun.

@MaheshRavishankar mentioned that we might need to extend SCFTilingResult to also contain a list of the tensor.extract_slices created by tiling.

Actually looking at this a little bit more, the best way forward I see is to add a field to TilingResult to return the slices of operands created during tiling that could be used for fusion. Then you could use the slices to build the worklist here and here to drive the operands that can be fused.

I think that is a good direction to take. Makes things less ambiguous, and avoids having to try to find the slices in the implementation of the tiling algorithm. This would require downstream users to change their implementation of getTiledImplementation and generateResultTileValue to return the slices to fuse along, but the changes should be fairly straight-forward to add.
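
As a rough illustration of this direction, the sketch below models a tiling result that carries its operand slices explicitly, and a fusion driver that builds its worklist from those slices instead of inspecting the operands of the last tiled op. It is plain C++ with stand-in types: `Slice`, `TilingResultModel`, `ProducerTable`, `fuseProducers`, and the field name `generatedSlices` are all hypothetical and are not the MLIR API.

```cpp
#include <deque>
#include <map>
#include <string>
#include <vector>

// Stand-in for a slice of an operand (hypothetical, not an MLIR type).
struct Slice {
  std::string source; // name of the value being sliced
};

// Stand-in for TilingResult with the proposed extra field.
struct TilingResultModel {
  std::vector<std::string> tiledOps;
  std::vector<std::string> tiledValues;
  std::vector<Slice> generatedSlices; // hypothetical field name
};

// Maps a value name to the result of tiling the op that produces it.
using ProducerTable = std::map<std::string, TilingResultModel>;

// Visits slices in BFS order, starting from the slices the root reported.
// Returns the names of the values whose producers were fused, in order.
inline std::vector<std::string> fuseProducers(const TilingResultModel &root,
                                              const ProducerTable &producers) {
  std::deque<Slice> worklist(root.generatedSlices.begin(),
                             root.generatedSlices.end());
  std::vector<std::string> fused;
  while (!worklist.empty()) {
    Slice candidate = worklist.front();
    worklist.pop_front();
    auto it = producers.find(candidate.source);
    if (it == producers.end())
      continue; // no tileable producer, e.g. a function argument
    fused.push_back(candidate.source);
    // The fused producer reports its own slices; they become new candidates.
    for (const Slice &slice : it->second.generatedSlices)
      worklist.push_back(slice);
  }
  return fused;
}
```

In this model, a tiled concat would report the slices of both `%in1` and `%in2` even though they sit inside the `scf.if`, so the driver never has to look for them in the IR.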

Thanks for the hints. I’ll take a stab at this. The two places you mention do match up with my understanding.

@raghavanr had another suggestion which might alleviate some of the rework needed by other clients of the TilingInterface as it stands now. Instead of modifying TilingResult directly, we extend TilingInterface with a new method which allows a client to return the tensor.extract_slices from the tiledOps field.

In the default case, we would just iterate over the operands of the tiled op. However, in my case, I could walk the region of the scf.if to return the contained tensor.extract_slice ops.
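
The shape of this alternative can be sketched as an interface method with a default implementation that an op may override. The model below is plain C++ with stand-in types. `OpModel`, `TiledConcatModel`, and the method name `getFusableSlices` are hypothetical and only illustrate the idea; they are not the actual TilingInterface.

```cpp
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for an operation (hypothetical, not mlir::Operation).
struct OpModel {
  explicit OpModel(std::string opName) : name(std::move(opName)) {}
  virtual ~OpModel() = default;

  std::string name;
  std::vector<const OpModel *> operands;  // defining ops of the operands
  std::vector<const OpModel *> nestedOps; // ops inside the op's regions

  // Default behavior: report the slices that directly define operands,
  // which matches what the fusion driver looks at today.
  virtual std::vector<const OpModel *> getFusableSlices() const {
    std::vector<const OpModel *> slices;
    for (const OpModel *def : operands)
      if (def->name == "tensor.extract_slice")
        slices.push_back(def);
    return slices;
  }
};

// Override for a tiled concat lowered to scf.if: the slices live inside the
// regions, so this version collects them from the nested ops instead.
struct TiledConcatModel : OpModel {
  using OpModel::OpModel;

  std::vector<const OpModel *> getFusableSlices() const override {
    std::vector<const OpModel *> slices;
    for (const OpModel *op : nestedOps)
      if (op->name == "tensor.extract_slice")
        slices.push_back(op);
    return slices;
  }
};
```

Existing ops would keep the default and need no changes, while an op like the tiled concat would opt in to the override.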

I am not familiar enough with the tradeoffs of changing the APIs in a “backwards incompatible” way to judge the merits of the two approaches.

Hi @MaheshRavishankar,

(sorry I was busy with some other stuff for a bit)

Acting on your suggestion, I added a new field extractedSliceOps to TilingResult (which then also needed to be carried over to scf::SCFTilingResult and scf::SCFFuseProducerOfSliceResult in order to work with tileConsumerAndFuseProducers). I also modified the current implementations of TilingInterface to populate the extract slice ops. I then modified the two locations you pointed me to above to use these to add to the worklist (instead of fetching them from the operands of the replacement value).

This worked for the most part except in the case when we hit this code path:

    if (yieldReplacement) {
      if (failed(yieldReplacementForFusedProducer(
              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
        return rewriter.notifyMatchFailure(
            fusableProducer.getOwner(), "failed to replacement value for this "
                                        "oepration from within the tiled loop");
      }
      origValToResultNumber[fusableProducer] =
          loops.front()->getNumResults() - 1;
    }

yieldReplacementForFusedProducer changes the argument of the fusedResult.tiledAndFusedProducer to refer to a newly created loop iteration argument. The original code then added the operands of the now modified op to the worklist. However, with the new approach, the extractedSliceOps are not updated. I could update yieldReplacementForFusedProducer to also modify the extractedSliceOps of the passed in SCFFuseProducerOfSliceResult. However, it is not at all clear how and whether we should even do it this way.

Any advice would be appreciated!

Thank you for following up on the suggestion and working through it. I am having trouble understanding the exact issue. Particularly

is unclear to me. The method doesn’t take fusedResult by reference to modify it. I am happy to help. I am also happy to get on a call if we want to have a high-bandwidth discussion and walk through the issue as we step through the code. If you give me a WIP PR and a repro, that would work too.

Thanks Mahesh!

After staring at yieldReplacementForFusedProducer for a while more, I did manage to make it work. I have a WIP PR here:

Please note that this is definitely a WIP (I haven’t modified the tensor dialect implementations of TilingInterface yet in this PR).

The lit test which exposed the issue with this is mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

I have a comment in the PR trying to explain the problem (and the potential solution).

I’d be happy to hop on a call as well if that works for you. Please let me know a good time.

Out of curiosity, how does this align with your desire to make the tiling interface as op-agnostic as possible? It looks like it would enshrine the fact of operand slicing happening during tiling.

(Not opposing the change, it somewhat aligns with my suggestion from another discussion to separate out data tiling from iteration space tiling: we would now explicitly return the result of data tiling even though it is still performed simultaneously with the iteration space tiling).

And how exactly would requiring clients to implement an additional method be less work for them?

I’m also concerned that this further hardcodes the need for extract_slice operations to be involved in tiling. This goes in the wrong direction regarding generality of the interface. And having to traverse IR to rediscover information that was readily available sounds contrary to the core design principles of MLIR. IIUC, Mahesh suggests above that operand slices (i.e., values) are returned, not the slicing operations.

And how exactly would requiring clients to implement an additional method be less work for them?

I can let @raghavanr confirm my understanding, but the idea is that existing implementations of TilingInterface would not have to do anything because the default implementation of the new method does what is being done currently.

This feels a bit moot given your other comment about trying to move away from needing extract_slice ops in tiling.

IIUC, Mahesh suggests above that operand slices (i.e., values) are returned, not the slicing operations.

Thanks for clarifying! This is not how I understood his suggestion. I had assumed that he wanted clients to explicitly return the list of tensor.extract_slice ops which they need to construct as part of generating the tiled op (this is how I implemented the PR above).

@MaheshRavishankar could you please clarify?

Please note that I have now implemented the TilingInterface for the tensor dialect ops as well, so I feel the PR is ready for review (if my understanding of your suggestion is actually correct).

This is exactly what I had proposed. The default implementation of this new method would handle all the existing scenarios (clearly that has been good enough until now), so nothing needs to be changed for them, while any op that needs a more custom implementation can add their own (which will help the case that Srinath had reported above).

There are two parts… For tiling itself you don’t need it, but this will help tile and fuse. Tile and fuse fundamentally relies on slicing (fusion is essentially converting op -> slice to slice -> op). Getting these slices from the tiling implementation allows better control over where the fusion points are. So it allows you to fuse along particular operands if the operation + tiling implementation allows it, but the tiling implementation can also return an empty set of slices. That would mean you can’t fuse with this operation, but it can still be tiled.

A related question (and something that Quinn pointed out in reviews) is that the fusion is keyed on tensor.extract_slice and tensor.insert_slice. The intent here has been to start with this and then generalize the tile and fuse implementation to use a more general interface for slices. That hasn’t materialized in tree yet.

Thanks!

That sounds acceptable as long as we clearly document where fusion is intentionally rendered impossible, as opposed to simply not having been implemented yet.

We already have the interfaces though (mlir/include/mlir/Interfaces/SubsetOpInterface.td in llvm/llvm-project at 9e63632b3274dc1b20502b569e79a311977e0a97), so wouldn’t it be preferable to use them in new code and not introduce more tech debt?