I have a use case where (1) there are operations that aren’t easily representable using structured ops constructs but would benefit from having the tiling / loop fusion capability of linalg. Additionally, (2) these operations can also map 1-1 to hardware primitives or external libraries, so preserving the patterns throughout transformations would also be required.
For example, I have a library that can compute softmax along a given axis of a tensor. I could represent the softmax denominator (the sum of e^x over every element x along that axis) using linalg, but doing so would mean breaking softmax down into smaller ops and potentially losing the pattern after tiling / fusion, which makes it hard to map those ops back to a library call later. Ideally, I’d like to make an operation’s implementation completely opaque to linalg while still leveraging linalg’s tiling / fusion logic.
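To make the goal concrete, here is a minimal sketch in plain Python (not MLIR) of what I mean by "opaque but still tilable". The names `library_softmax` and `tiled_softmax` are hypothetical stand-ins, and I'm assuming the library reduces along the last axis of a 2D input. The point is only that tiling the non-reduced axis leaves each tile as one intact library call:

```python
import math

def library_softmax(rows):
    """Hypothetical stand-in for the opaque library call.

    Computes softmax along the last axis of a 2D list of floats.
    """
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def tiled_softmax(rows, tile_rows):
    """Tile only the non-reduced (row) axis.

    Each tile is still handed to the library as a whole 2D block, so the
    softmax pattern is never broken into smaller ops.
    """
    out = []
    for start in range(0, len(rows), tile_rows):
        out.extend(library_softmax(rows[start:start + tile_rows]))
    return out

x = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [-1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]
full = library_softmax(x)
tiled = tiled_softmax(x, tile_rows=2)
```

Under these assumptions `tiled` matches `full`, because every row is processed by exactly the same computation regardless of which tile it lands in. This is the behavior I'd like the tiling logic to give me without it ever looking inside the op.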
After some reading, I found that (1) is roughly what the TilingInterface aims to solve. Unfortunately, my understanding is that the TilingInterface work isn’t complete: only the interface itself is in MLIR at the moment, and the linalg tiling algorithm doesn’t make use of it yet. I’d greatly appreciate it if someone could share more about the plans for TilingInterface.
Because TilingInterface isn’t yet an option, I was also looking for ways to satisfy both requirements using what’s currently available. I came across this discussion, which suggests using a linalg.generic op with a call op (or an assert) in the region. This way, I don’t need to worry about generating bodies for ops that can’t be represented easily in linalg. However, if my library signature takes only 2D tensors, how does this work when the linalg region takes only scalar values? I could make the type system happy by casting the region’s inputs to the appropriate type. But I’m worried that doing so would confuse the element-wise fusion algorithm into thinking the library call can be fused with other ops when in fact it cannot (e.g. fusing a library call that takes 2D tensors with an arithmetic op that takes only scalars).
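To illustrate the fusion concern, here is a small plain-Python sketch (again not MLIR, and not how linalg actually performs fusion). `library_softmax_2d` is a hypothetical stand-in for my 2D library call, and `misfused_elementwise` models what would happen if the call were treated as a per-scalar body, i.e. invoked once per element as if it were element-wise:

```python
import math

def library_softmax_2d(rows):
    """Hypothetical 2D library call: softmax along the last axis."""
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def misfused_elementwise(rows):
    """Models a scalar-region view of the call.

    Each element is wrapped in its own 1x1 block and passed to the
    library separately, as an element-wise fusion would assume is legal.
    """
    return [[library_softmax_2d([[v]])[0][0] for v in row] for row in rows]

x = [[1.0, 2.0, 3.0]]
correct = library_softmax_2d(x)   # normalizes across the whole row
wrong = misfused_elementwise(x)   # every 1x1 softmax normalizes to 1.0
```

Here `wrong` is all ones while `correct` is a proper distribution over the row, since each output element of softmax depends on every input element along the axis. That dependence is what I'm afraid a scalar-typed region would hide from the fusion logic.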
I’d greatly appreciate any ideas and pointers.