[RFC][mlir][vector] Modify/remove bit-width specific flattening/linearization

Hi all,

This RFC proposes changing the way bit-width dependent flattening/linearization works, or removing it entirely.

Question I’d most like answered

Are there still users of the bit-width logic, either for linearization or for the flattening of transfer_read and transfer_write? If not, maybe we can jump straight to proposal 1 (remove).

History of bit-width dependent flattening

This is just what I have uncovered! The logic for bit-width dependent flattening was introduced in this PR. Shortly thereafter the logic was extended to the closely related linearization pass in this PR.

Reviewers and authors left comments in those PRs about future work on partial unrolling/flattening, such as this one. Those comments seem closely related to this RFC, specifically to proposal 3 below.

I refactored the logic for linearization here and separated the bit-width specific logic from the core linearization. That separation was possible because linearization uses type conversion (unlike most of Vector).

What is bit-width dependent flattening

The idea is that ops whose operands/results have inner dimensions larger than some threshold are left unchanged. For example, if the bit-width threshold is at most 8 bits in:

%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>

then the op is left unchanged, otherwise it is linearized to a rank-1 shuffle. The logic for this is here.
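For concreteness, when the op is linearized (threshold larger than the inner dimension's 8 x i1 = 8 bits), the result would look something like this (a sketch; shuffle indices elided):

%s = vector.shape_cast %arg0 : vector<2x32x8xi1> to vector<512xi1>
%e = vector.shuffle %s, %s [256, 257, ..., 511] : vector<512xi1>, vector<512xi1>
%0 = vector.shape_cast %e : vector<256xi1> to vector<32x8xi1>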

For transfer_read and transfer_write in the original PR, it is essentially the same: if the vector that is read or written has a large enough inner dimension, it is not flattened at all. Logic here.
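As a sketch of what flattening does when it is not skipped (types chosen for illustration; the pattern assumes a contiguous row-major memref):

%v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]} : memref<4x8xi8>, vector<4x8xi8>

becomes, roughly,

%m = memref.collapse_shape %mem [[0, 1]] : memref<4x8xi8> into memref<32xi8>
%r = vector.transfer_read %m[%c0], %pad {in_bounds = [true]} : memref<32xi8>, vector<32xi8>
%v = vector.shape_cast %r : vector<32xi8> to vector<4x8xi8>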

Unintuitive behavior

The current behavior is OK, but in my opinion it could be better. To me it is strange that we can have two ops that would be linearized to exactly the same thing, but only one of them is linearized. Example:

%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>

might be left unchanged (if the threshold is at most 8 bits), while

%0 = vector.extract %arg0[1]: vector<32x8x1xi1> from vector<2x32x8x1xi1>

is linearized/flattened to

%1 = vector.shuffle [...] [indices] : vector<512xi1>, vector<512xi1>

with surrounding shape_cast ops changing the rank. %0 would be linearized to exactly the same shuffle if the threshold were larger than 8.

Proposal 1

Remove the bit-width specific logic entirely. First, we need to check who relies on it. PR: [mlir][vector] Remove bit-width logic tests (llvm/llvm-project#143007).

Proposal 2

Make the bit-width specific logic depend on the total number of elements in the vector, not just the inner dimension.
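Using the first example above for illustration:

// current rule:  decided by the inner dimension only, 8 x i1 = 8 bits
// this proposal: decided by the whole vector, 2 x 32 x 8 = 512 elements
%0 = vector.extract %arg0[1] : vector<32x8xi1> from vector<2x32x8xi1>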

Proposal 3

Legalize ops to their ‘nearest’ legal form. In the case above

%0 = vector.extract %arg0[1]: vector<32x8x1xi1> from vector<2x32x8x1xi1>

would be converted to

%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>

In other words, this proposal is to reduce the rank of operands/results by incrementally collapsing the inner 2 dimensions, until the bit-width threshold is met (or the rank is 1).
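As a sketch on the example above, with a hypothetical threshold of 32 bits (the surrounding shape_cast ops are omitted):

// trailing dim is 1 x i1 = 1 bit < 32: collapse the inner 2 dims
%0 = vector.extract %arg0[1] : vector<32x8xi1> from vector<2x32x8xi1>
// trailing dim is 8 x i1 = 8 bits < 32: collapse again
%0 = vector.extract %arg0[1] : vector<256xi1> from vector<2x256xi1>
// trailing dim is 256 x i1 = 256 bits >= 32: stop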

Similar logic could be applied to transfer_read/transfer_write flattening. I think this is similar to what @hanchung suggests here.

Proposal 4

Leave the logic as it is. This is OK, although it blocks me from implementing some improvements to linearization (rough description here).

Thank you for reading!

CC @dcaballe who implemented the initial logic.

Thanks for writing this up and for providing all the context - that’s very helpful!

First, an apology: in your PR, I commented in support of removing the bit-width-related logic. However, with the additional context (and my memory refreshed), I realise this was all part of a broader design we never fully implemented. So I’d like to retract that support for now.

This seems to stem from the presence of the trailing unit dimension. In other words, we’re hitting the flattening logic before certain shape normalization or pre-processing steps have occurred. We might indeed be missing a few canonicalization patterns here, but that’s the first direction I’d explore.

In fact, this is essentially what you describe in Proposal 3:

That said, I don’t fully follow the phrase “inner 2 dimensions” - in your example, it looks like only a single dimension is being collapsed. But that’s just a nit.

This last point makes me wonder: shouldn’t the bit-width logic be a no-op when targetVectorBitWidth is not provided? If not, maybe that’s something we should fix directly?


My overall thinking is that there is a lot of nuance and unfinished work here. Once we have a clear mental model, we should be able to converge towards something that unblocks you.

Let me know what you think - and thanks again for putting this together! I think it’s a good time to revisit this area.

-Andrzej

So I’d like to retract that support for now.

That’s totally fine. I agree that the decision should be based on this broader scope. Writing this RFC was a good exercise; it forced me to reveal more of the context.

That said, I don’t fully follow the phrase “inner 2 dimensions” - in your example, it looks like only a single dimension is being collapsed. But that’s just a nit.

I think we mean the same thing: by “collapse the inner N dimensions” (of a rank M+N thing) I mean flatten those N dimensions into a single dimension, going from rank M+N to rank M+1. I’ll be more explicit next time to avoid this ambiguity!

This seems to stem from the presence of the trailing unit dimension.

I’d rather say that the unit dimension reveals the extreme case. To me, the unintuitive behavior persists even when the trailing dim isn’t 1. Consider three types: T1) 1600xi1, T2) 100x16xi1, T3) 100x4x4xi1.

With a bit-width threshold of 8, transfer_read operations are converted as follows: T1 -> T1, T2 -> T2, and T3 -> T1. Which is better, T2 or T1? If T1 is better, then why not T2 -> T1? If T2 is better, then why not T3 -> T2?
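Spelled out as the resulting transfer_read vector types (a sketch):

// threshold = 8 bits, element type i1
// T1: vector<1600xi1>    -> vector<1600xi1>    (already rank 1)
// T2: vector<100x16xi1>  -> vector<100x16xi1>  (inner dim 16 bits >= 8: kept)
// T3: vector<100x4x4xi1> -> vector<1600xi1>    (inner dim 4 bits < 8: fully flattened)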

The current algorithm is:
If the rank-1 form is better, convert to the rank-1 form.

The gradual lowering (proposal 3) algorithm is:
While the rank N-1 form is better, convert from the rank N form to the rank N-1 form.
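Concretely, with the types above and a threshold of 8 bits (a sketch):

// current: T3 (100x4x4xi1) -> T1 (1600xi1), in one direct step
// gradual: T3 (100x4x4xi1) -> T2 (100x16xi1), then stop, since the
//          inner dim of T2 (16 x i1 = 16 bits) already meets the threshold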

The latter seems better to me. I can’t provide a strong argument for why.

This last point makes me wonder: shouldn’t the bit-width logic be a no-op when targetVectorBitWidth is not provided? If not, maybe that’s something we should fix directly?

It is, yes. The issue I have is that if I change the patterns to flatten more gradually, the bit-width related linearization tests fail: linearization gets stuck at T2 and doesn’t go on to T1. Just to clarify, I am not blocked from making progress if we don’t change the logic (proposal 4). It would just mean less compact code, because I would need to retain the ‘direct path’ patterns.

Another thought I’ve had is that flattening of transfer_read/transfer_write (code) is a kind of linearization, and maybe Vector could be simplified by moving it to linearization (with options to choose subsets of patterns).

Hey, thanks for bringing this up! Sorry, I had missed the PR!

Could you elaborate a bit more on the blocking aspect of this? The bitwidth threshold is an optional knob. I’m not sure I follow after reading the PR description.

Unit dimensions have been problematic for many transformations all over the compiler and, unfortunately, dealing with them gracefully in each and every pass is complex and expensive. If this is a major source of problems for this transformation, documenting that removing redundant unit dimensions is a prerequisite for this transformation should help set expectations.

Yes, proposal 3 seems like the right way to go. Vector unrolling and linearization are used in different ways by different projects, but one of the common goals is to make sure we optimize the use of physical vector registers, which is what my comment was about at the time.

Thanks for your feedback, Diego! It all makes sense to me.

Could you elaborate a bit more on the blocking aspect of this? The bitwidth threshold is an optional knob. I’m not sure I follow after reading the PR description.

This is not really blocking me. It just means I can’t implement an orthogonal change as compactly as I’d like without changing the behavior that users of the knob observe.

At an abstract level, there is an existing pattern to go from state A to state C. I want to add two new patterns, A->B and B->C. With these, I should in theory be able to remove A->C. But I can’t, because with the current bit-width logic A might be illegal and B legal, so we get ‘stuck’ at B. (In the example above, A, B and C play the roles of T3, T2 and T1.)

This is an interesting example - thanks for sharing! It highlights an assumption we make in the current infrastructure that’s probably not documented explicitly.

Basically, we assume that we start with a shape like T3 (e.g., 100x4x4xi1), which doesn’t have enough elements in the innermost dimension to fit a “good” vector width. So we collapse dimensions, increasing the size of the trailing dimension, ideally producing something closer to a target-native vector width.

For example, collapsing 100x4x4xi32 to 100x16xi32 makes sense if we think in terms of generating LLVM IR with 16-element vector types. That’s a size that LLVM is comfortable with - even if the hardware doesn’t directly support 16xi32, LLVM can lower that to 4xi32 chunks, which are supported on typical hardware.
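In IR terms that collapse is a reshape along the lines of (a sketch; the producers/consumers of %v would be rewritten accordingly):

%w = vector.shape_cast %v : vector<100x4x4xi32> to vector<100x16xi32>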

In that sense, something like 1600xi1 is already considered “good enough” and doesn’t need to be further reshaped. Of course, LLVM will still need to break it into legal vector sizes (and will likely spend some time doing so). Ideally, the input shapes we generate should already be near the native vector size - say 4, 8, 16 elements - so we minimize that overhead.

Presumably, in your “universe” (not CPU and not LLVM?), T1 is not great and you would rather have T2?

Does that match your understanding of how this is supposed to work? And do you think that explains the behaviour you’re seeing?

Presumably, in your “universe” (not CPU and not LLVM?), T1 is not great and you would rather have T2?

As an end point, T1 is good for me. But how we get from T3 to T1 is interfering with some changes I’d like to make. It’s not a serious issue. At this point, I propose we make this RFC more about finding a well-defined and flexible approach for everyone than about my non-blocking problem!

Does that match your understanding of how this is supposed to work?

Yes, I think so. Based on your comments, am I right to think you’re leaning towards proposal 3? Ignoring possible churn, code drift, wasted effort, etc.

Thanks for discussing this, @banach-space and @dcaballe. I was quite hoping to hear from folks who use the bit-width logic; I’ll give it some more time. If I don’t hear more in the coming days/weeks, I’ll probably start experimenting with proposal 3 to see if it leads anywhere sensible. That’s not to say I’m taking proposal 3 as accepted; I just think some code might clarify the idea.

That’s a very healthy approach - thank you for being flexible!

Proposal 3 makes the most sense to me at the moment as well, but I do share your concern about:

But like you said, it’s hard to reason about these risks without more concrete examples :slight_smile: