Using linalg to tile opaque operations not representable using affine maps

I have a use case where (1) some operations aren't easily representable with structured-op constructs but would benefit from linalg's tiling and loop-fusion capabilities, and (2) these operations map 1-1 to hardware primitives or external libraries, so they also need to be preserved as recognizable units throughout transformations.

For example, I have a library that computes softmax along a given axis of a tensor. Although I can express the softmax denominator (the sum of e^x over every element x along that axis) in linalg, doing so means breaking softmax into smaller ops, and after tiling and fusion the original pattern may be lost, making it hard to map those ops back to a library call later. Ideally, I'd like an operation's implementation to be completely opaque to linalg while still being able to leverage linalg's tiling and fusion logic.
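To make the concern concrete, here is a minimal pure-Python sketch of softmax broken into the primitive steps a structured-op lowering would typically produce: a max-reduction, an elementwise subtract-and-exp, a sum-reduction for the denominator, and an elementwise divide. It assumes a 2D input reduced over the row index (one reading of the "reduction across rows" restriction described later in the thread); the function name is illustrative only.

```python
import math

def softmax_2d_decomposed(x):
    """Softmax of a 2D list `x` (rows x cols), reducing over the row index,
    so each column is normalized independently. Written as the four
    primitive steps a structured-op lowering would typically produce."""
    rows, cols = len(x), len(x[0])
    # Step 1: max-reduction over the row index (for numerical stability).
    col_max = [max(x[r][c] for r in range(rows)) for c in range(cols)]
    # Step 2: elementwise subtract-and-exponentiate.
    e = [[math.exp(x[r][c] - col_max[c]) for c in range(cols)] for r in range(rows)]
    # Step 3: sum-reduction over the row index: the softmax denominator.
    denom = [sum(e[r][c] for r in range(rows)) for c in range(cols)]
    # Step 4: elementwise divide.
    return [[e[r][c] / denom[c] for c in range(cols)] for r in range(rows)]
```

Once these four steps are tiled and fused with neighboring ops independently, nothing in the IR ties them back to a single `softmax` library call.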

After some reading, I found that (1) is roughly what the TilingInterface aims to solve. Unfortunately, my understanding is that this work isn't complete: only the interface itself is in MLIR at the moment, and linalg's tiling algorithm doesn't use it yet. I'd greatly appreciate it if someone could share more about the plans for TilingInterface.

Because TilingInterface isn't an option yet, I also looked for ways to satisfy both requirements with what's currently available. I came across this discussion, which suggests using a linalg.generic op with a call op (or an assert) in its region. That way, I don't need to generate bodies for ops that can't be represented easily in linalg. However, if my library only accepts 2D tensors, how does this work when the linalg region only receives scalar values? I could make the type system happy by casting the region's inputs to the appropriate type, but I'm worried this would confuse the element-wise fusion algorithm into thinking the library call can be fused with other ops when in fact it cannot (e.g. fusing a library call that takes 2D tensors with an arithmetic op that only takes scalars).
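The fusion worry can be illustrated with a small probe: element-wise fusion is only valid when each output element depends solely on the input element at the same index. The hypothetical helper `outputs_affected` below perturbs one input and reports which outputs change. Linalg itself does not probe ops like this; it trusts the op's declared indexing maps and iterator types, which is why wrapping a non-pointwise library call in a generic with identity maps and parallel iterators could license fusions that are not actually legal.

```python
import math

def softmax_1d(xs):
    """Softmax over a 1D list (stands in for one column of the 2D kernel)."""
    m = max(xs)
    e = [math.exp(v - m) for v in xs]
    s = sum(e)
    return [v / s for v in e]

def elementwise_exp(xs):
    """A genuinely pointwise op: output i depends only on input i."""
    return [math.exp(v) for v in xs]

def outputs_affected(op, xs, i, delta=1.0):
    """Indices of outputs of `op` that change when only input i is perturbed."""
    base = op(xs)
    bumped = list(xs)
    bumped[i] += delta
    after = op(bumped)
    return [j for j, (a, b) in enumerate(zip(base, after)) if a != b]
```

For the pointwise op only index 1 changes when input 1 is bumped; for softmax every output changes, because the shared denominator changes.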

I’d greatly appreciate any ideas and pointers. :slight_smile:

I've been thinking about a similar issue, but because softmax involves a reduction, your case seems to have some nuances I didn't catch from your description. How does your softmax library call operate on a tile? Could you post example IR of what you'd ideally like to achieve?

Thank you for your reply, and sorry for not being clear above. My softmax library only accepts 2D tensors and performs the reduction across rows. For softmax operations with rank >= 3, there's a restriction that the reduction is across the second-to-last dimension (so also across rows). In such cases, I want to use linalg to “tile” the outer batch dimensions, generating a loop that calls our 2D softmax, and potentially fuse this loop with other ops that share the same batch dimensions.

So

```mlir
%0 = linalg.init_tensor [8, 256, 512]: tensor<8x256x512xbf16> // input
%1 = linalg.init_tensor [8, 256, 512]: tensor<8x256x512xbf16> // output
%2 = linalg.softmax ins(%0: tensor<8x256x512xbf16>) outs(%1: tensor<8x256x512xbf16>)
```

will become (sorry I’m not using entirely correct IR form, but hopefully the idea is clear)

```mlir
%0 = linalg.init_tensor [8, 256, 512]: tensor<8x256x512xbf16> // input
%1 = linalg.init_tensor [8, 256, 512]: tensor<8x256x512xbf16> // output
%2 = scf.for %arg0 = 0 to 8 step 1 iter_args(%arg1 = %0) -> (tensor<8x256x512xbf16>) {
  %3 = tensor.extract_slice %0[%arg0, 0, 0] [1, 256, 512] [1, 1, 1]: tensor<256x512xbf16>
  %4 = linalg.softmax(ins %3): tensor<256x512xbf16>
  %5 = tensor.insert_slice %4 into %1[%arg0, 0, 0] [1, 256, 512] [1, 1, 1]
  scf.yield %5
}
```

Ok, it is clear now. EDIT: in your example I think you want to extract_slice from %input (not from init_tensor) and insert_slice into the iterarg.
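With that correction applied, the loop's semantics can be sketched in plain Python: slice each batch plane out of the input, hand it to the opaque 2D kernel, and insert the result into the loop-carried value rather than into the original init tensor. `softmax_2d` and `tiled_softmax_3d` are hypothetical stand-ins, and the kernel assumes the reduction runs over the row index (the second-to-last dimension of the 3D tensor).

```python
import math

def softmax_2d(tile):
    """Stand-in for the 2D library kernel: reduces over the row index
    (each column normalized independently). Treated as opaque below."""
    rows, cols = len(tile), len(tile[0])
    out = [[0.0] * cols for _ in range(rows)]
    for c in range(cols):
        m = max(tile[r][c] for r in range(rows))
        e = [math.exp(tile[r][c] - m) for r in range(rows)]
        s = sum(e)
        for r in range(rows):
            out[r][c] = e[r] / s
    return out

def tiled_softmax_3d(inp, init):
    """Mirror of the corrected scf.for: slice from the input, call the
    opaque 2D kernel, and insert the result into the loop-carried value."""
    acc = [[row[:] for row in plane] for plane in init]  # iter_args(%arg1 = %init)
    for b in range(len(inp)):                    # scf.for over the batch dim
        tile = inp[b]                            # tensor.extract_slice %input[b, 0, 0]
        result = softmax_2d(tile)                # opaque 2D library call
        acc = acc[:b] + [result] + acc[b + 1:]   # tensor.insert_slice into %arg1
    return acc                                   # value yielded by the loop
```

Only the batch dimension is tiled, so every kernel call still sees a whole reduction dimension; that is what keeps the tiled result equal to the untiled one.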

First, I don’t think adding a linalg.softmax would align with linalg’s philosophy (but someone please correct me if I’m wrong). So I changed it to x.softmax where x could be your dialect or one of the other higher-level dialects.

For fusion, let’s divide it into cases based on whether x.softmax is consuming another op or producing a tensor consumed by another linalg op.

If x.softmax is the producer, I think you can get something working without modifying the linalg dialect or adding a tiling interface to your operation. You would add a rewrite that, when a tensor.extract_slice of the softmax result only slices parallel dims (relative to the softmax output), replaces it with a new x.softmax applied to the corresponding slice of the softmax input:

```mlir
%1 = x.softmax %input -> tensor<...>
%2 = scf.for %arg0 = 0 to 8 step 1 iter_args(%arg1 = %0) -> (tensor<8x256x512xbf16>) {
  %3 = tensor.extract_slice %1[%arg0, 0, 0] [1, 256, 512] [1, 1, 1]: tensor<256x512xbf16>
  %4 = linalg.generic {...} ins(%3: tensor<...>) outs(....) -> tensor<...>
  %5 = tensor.insert_slice %4 into %arg1[%arg0, 0, 0] [1, 256, 512] [1, 1, 1]
  scf.yield %5
}
```

becomes

```mlir
%1 = x.softmax %input -> tensor<...>
%2 = scf.for %arg0 = 0 to 8 step 1 iter_args(%arg1 = %0) -> (tensor<8x256x512xbf16>) {
  %3 = tensor.extract_slice %input[%arg0, 0, 0] [1, 256, 512] [1, 1, 1]: tensor<256x512xbf16>
  %4 = x.softmax %3 -> tensor<...>
  %5 = linalg.generic {...} ins(%4: tensor<...>) outs(....) -> tensor<...>
  %6 = tensor.insert_slice %5 into %arg1[%arg0, 0, 0] [1, 256, 512] [1, 1, 1]
  scf.yield %6
}
```

(then %1 is unused and gets removed during canonicalization)
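The rewrite is only legal under the condition stated above: the slice must leave every reduction dimension of the softmax whole. A minimal sketch of that check, plus a demonstration of why it matters, assuming the same kind of stand-in kernel that reduces over dim 1 of a 3D tensor (the helper names are hypothetical, not an MLIR API):

```python
import math

def softmax_2d(t):
    """Stand-in 2D kernel: reduces over the row index (per column)."""
    cols = range(len(t[0]))
    m = [max(row[c] for row in t) for c in cols]
    e = [[math.exp(row[c] - m[c]) for c in cols] for row in t]
    s = [sum(row[c] for row in e) for c in cols]
    return [[row[c] / s[c] for c in cols] for row in e]

def softmax_3d(x):
    """Full op: dim 0 (batch) is parallel, reduction over dim 1."""
    return [softmax_2d(plane) for plane in x]

def can_swap_slice_with_softmax(offsets, sizes, full_shape, reduction_dims):
    """Hypothetical legality check: extract_slice(softmax(x)) may become
    softmax(extract_slice(x)) only if every reduction dim is kept whole,
    i.e. the slice cuts parallel dims only."""
    return all(offsets[d] == 0 and sizes[d] == full_shape[d] for d in reduction_dims)
```

Slicing the batch dim (offsets `[1, 0, 0]`, sizes `[1, 2, 2]`) passes and gives the same values either way. Slicing a single row along the reduction dim fails the check, and recomputing softmax on that one-row slice yields all ones, which does not match the corresponding row of the full result.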

If x.softmax is the consumer, it’s a bit more complicated, as you suggested. I know the tensor.pad operation implements the tiling interface, so you could follow that example, but I think this requires adding a case into the linalg tiling functions right now. Alternatively, the IREE folks seem to have a more robust tiling interface in their linalg_ext dialect here; maybe someone familiar with those extensions can comment on whether they will make it into upstream soon? I am also interested in this.

As a workaround, you could write your own routine that drives tiling and fusion rooted at your x.softmax operation: tile x.softmax using TilingInterface, then loop over its tensor operands and try to tile and pull in the producers of those operands. There's a utility class, TileLoopNest, here that you could copy and modify as a starting point.
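The shape of such a driver can be sketched as a worklist over producers. This is a toy model in Python, not the TileLoopNest API: `Op`, `tileable` (standing in for "implements TilingInterface"), and `tile_and_fuse` are hypothetical, and the sketch only decides which ops end up inside the tiled loop. A real implementation would also have to compute each producer's slice from the consumer's tile and check legality, for example that the slice does not cut a reduction dimension.

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    name: str
    operands: list = field(default_factory=list)  # producer Ops of the operands
    tileable: bool = True  # stands in for "implements TilingInterface"

def tile_and_fuse(root):
    """Hypothetical greedy driver: tile the root, then walk its operands
    and pull in every producer that can be tiled, transitively."""
    fused = [root]          # ops that end up inside the tiled loop nest
    worklist = [root]
    seen = {id(root)}       # Op is unhashable (dataclass eq=True), so track ids
    while worklist:
        op = worklist.pop()
        for producer in op.operands:
            if id(producer) in seen:
                continue
            seen.add(id(producer))
            if producer.tileable:
                fused.append(producer)
                worklist.append(producer)
    return [op.name for op in fused]
```

A non-tileable producer stops the walk along its branch, so its own producers stay outside the loop even if they are tileable.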

These sound like the log-sum-exp type of layers I’ve been discussing with some folks separately. This notion of aggregate op that consists of a sequence of primitive linalg.generic ops is interesting but not yet supported.

There is ongoing work to upstream some of the IREE Tiling interface generalizations and to refactor tiling and fusion to work more generally (for instance, IREE has gather, FFT, Scan and other ops that tile IIRC).
These refactorings will also allow generalizations to other types of subset and tiling abstractions than just “strided rectangular arrays”.

So this is happening and expected to make good progress within the next few weeks (@MaheshRavishankar).

> These sound like the log-sum-exp type of layers I’ve been discussing with some folks separately. This notion of aggregate op that consists of a sequence of primitive linalg.generic ops is interesting but not yet supported.

Thank you for the updates! I’m glad that these ideas are being explored.

My team at Microsoft and I are very excited about the TilingInterface; please let me know if there's anything we can help with!

We’d be open to having a (public) meeting sometime to coordinate. The Thursday ODM times are easiest to schedule, but we're also open to a one-off, as long as we announce it. Lmk.
