Hi all,
This RFC proposes changing the way bit-width dependent flattening/linearization works, or removing it entirely.
Question I’d most like answered
Are there still users of the bit-width logic, either for linearization or for the flattening of transfer_read and transfer_write? If not, maybe we can jump straight to proposal 1 (remove).
History of bit-width dependent flattening
This is just what I have uncovered! The logic for bit-width dependent flattening was introduced in this PR. Shortly thereafter the logic was extended to the closely related linearization pass in this PR.
There were comments left by reviewers/authors in those PRs about future work to have partial unrolling/flattening, such as this one. These seem closely related to this RFC, specifically proposal 3 below.
I refactored the logic for linearization here and separated the bit-width specific logic from the core linearization. That separation was possible because linearization uses type conversion (unlike most of Vector).
What is bit-width dependent flattening
The idea is that ops whose operands/results have an inner dimension with a bit-width larger than some threshold are left unchanged. For example, if the bit-width threshold is less than 8 in:
%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>
then the op is left unchanged, otherwise it is linearized to a rank-1 shuffle. The logic for this is here.
For transfer_read and transfer_write in the original PR, it is essentially the same: if the vector that is read or written has a large enough inner dimension, it is not flattened at all. Logic here.
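A minimal sketch of the check described above, written in Python as a model rather than as the actual MLIR C++ code. The function names are hypothetical, and treating a width exactly equal to the threshold as "large enough" is an assumption.

```python
from typing import Sequence


def trailing_dim_bits(shape: Sequence[int], elem_bits: int) -> int:
    """Bits spanned by the innermost dimension (extent times element width)."""
    if not shape:
        raise ValueError("rank-0 vectors have no inner dimension")
    return shape[-1] * elem_bits


def is_left_unchanged(shape: Sequence[int], elem_bits: int,
                      threshold_bits: int) -> bool:
    """True if the inner dimension is large enough that the op is not linearized.

    Assumption: a width exactly equal to the threshold counts as large enough.
    """
    return trailing_dim_bits(shape, elem_bits) >= threshold_bits


# vector<32x8xi1>: the inner dimension spans 8 * 1 = 8 bits.
print(is_left_unchanged((32, 8), elem_bits=1, threshold_bits=4))   # True
print(is_left_unchanged((32, 8), elem_bits=1, threshold_bits=16))  # False
```

Only the last entry of the shape enters the decision, which is what the next section takes issue with.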
Unintuitive behavior
The current behavior is ok, but could be better in my opinion. To me, it is strange that we can have 2 ops that would be linearized to exactly the same thing, but only one is linearized. Example:
%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>
might be unchanged (if the threshold is less than 8 bits), while
%0 = vector.extract %arg0[1]: vector<32x8x1xi1> from vector<2x32x8x1xi1>
is linearized/flattened to
%1 = vector.shuffle [...] [indices] : vector<512xi1>, vector<512xi1>
with surrounding shape_cast ops changing the rank. The first op (on vector<2x32x8xi1>) would be linearized to exactly the same shuffle if the threshold were larger than 8.
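The mismatch can be reproduced with a small Python model of the inner-dimension rule. This is a sketch, not the real implementation: `is_linearized` and the threshold of 4 bits are made up for illustration, and a strict "less than" comparison is assumed.

```python
from math import prod


def is_linearized(shape, elem_bits, threshold_bits):
    """Inner-dimension rule: linearize only if the last dim is below the threshold."""
    return shape[-1] * elem_bits < threshold_bits


THRESHOLD_BITS = 4        # hypothetical threshold, below the 8-bit inner dim
rank3 = (2, 32, 8)        # vector<2x32x8xi1>
rank4 = (2, 32, 8, 1)     # vector<2x32x8x1xi1>

# Both operands flatten to the same vector<512xi1> ...
same_flat_size = prod(rank3) == prod(rank4) == 512

# ... but only the one with the trailing unit dimension is linearized.
decisions = (is_linearized(rank3, 1, THRESHOLD_BITS),
             is_linearized(rank4, 1, THRESHOLD_BITS))
print(same_flat_size, decisions)  # True (False, True)
```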
Proposal 1
Remove the bit-width specific logic entirely. We first need to check who relies on it. PR: [mlir][vector] Remove bit-width logic tests by newling (llvm/llvm-project pull request #143007).
Proposal 2
Make the bit-width specific logic depend on the total number of elements in the vector, not just the inner dimension.
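One way this could look, sketched in Python. Comparing the total bit count (number of elements times element width) against the threshold is my assumption about how "total number of elements" would be applied, and the function names are hypothetical.

```python
from math import prod


def total_bits(shape, elem_bits):
    """Bits in the whole vector: number of elements times element width."""
    return prod(shape) * elem_bits


def is_linearized_by_total(shape, elem_bits, threshold_bits):
    """Total-size rule: linearize only if the whole vector is below the threshold."""
    return total_bits(shape, elem_bits) < threshold_bits


# vector<32x8xi1> and vector<32x8x1xi1> hold the same 256 bits, so a rule based
# on total size cannot treat them differently, whatever the threshold is.
for threshold in (4, 256, 1024):
    a = is_linearized_by_total((32, 8), 1, threshold)
    b = is_linearized_by_total((32, 8, 1), 1, threshold)
    print(threshold, a, b)
```

With this rule a trailing unit dimension no longer changes the outcome.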
Proposal 3
Legalize ops to their ‘nearest’ legal form. In the case above
%0 = vector.extract %arg0[1]: vector<32x8x1xi1> from vector<2x32x8x1xi1>
would be converted to
%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>
In other words, this proposal is to reduce the rank of operands/results by incrementally collapsing the inner 2 dimensions, until the bit-width threshold is met (or the rank is 1).
Similar logic could be applied to transfer_read/transfer_write flattening. I think this is similar to what is suggested by @hanchung here.
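The incremental collapse can be sketched as a loop over the shape. This is a Python model under stated assumptions, not a proposed implementation: `collapse_to_nearest_legal` is a hypothetical name, a width equal to the threshold is assumed to be legal, and each type is handled on its own (a real pattern would have to collapse the operands and results of an op consistently).

```python
def collapse_to_nearest_legal(shape, elem_bits, threshold_bits):
    """Merge the two innermost dims until the inner dim reaches the threshold.

    Stops early at rank 1, where nothing is left to collapse.
    """
    dims = list(shape)
    while len(dims) > 1 and dims[-1] * elem_bits < threshold_bits:
        inner = dims.pop()
        dims[-1] *= inner  # fold the innermost dim into its neighbour
    return tuple(dims)


# With a hypothetical 4-bit threshold, only the trailing unit dim is folded:
# vector<32x8x1xi1> -> vector<32x8xi1>
print(collapse_to_nearest_legal((32, 8, 1), 1, 4))      # (32, 8)
# With a larger threshold the collapse continues, here down to rank 1.
print(collapse_to_nearest_legal((32, 8, 1), 1, 16))     # (256,)
# The operand type stops at rank 2 once its inner dim is wide enough.
print(collapse_to_nearest_legal((2, 32, 8, 1), 1, 16))  # (2, 256)
```

The number of elements is preserved at every step, so each result corresponds to a shape_cast of the original type.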
Proposal 4
Leave the logic as it is. This is ok, although it blocks me from implementing some improvements to linearization (rough description here).
Thank you for reading!