I think this RFC is headed in the right direction overall, but I’m -1 on the proposed change. I’ll lay out my mental model of the current state of the vector dialect and then discuss the proposed change against it.
General vector dialect support for 0-d vectors
The problem with 0-d vector support in the vector dialect isn’t 0-d vectors themselves, it’s how operations handle them as special cases. Generally, you can split vector dialect operations into 3 categories:
- Category 1: Operations defined on N-D vector space
- Category 2: Operations defined on 1-D vectors, extended to N-D vectors by
treating them as a “stack of 1-D vectors”.
- Category 3: Operations defined on 1-D vectors (usually mapping to LLVM/SPIR-V intrinsics)
Note that Category 2 operations are a restriction of Category 1, and Category 3 operations are a restriction of Category 2. An operation defined as Category 1 can be used as a Category 2 or Category 3 operation, but not the other way around.
Operations in Category 1 require 0-D vectors to be defined properly, since they work on an N-D space. Treating 0-D vectors as scalars for these operations is special casing and has caused multiple bugs (llvm/llvm-project PR #116650, “[mlir][Vector] Add support for 0-d shapes in extract-shape_cast folder”; PR #113828, “[mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts”; PR #112907, “[mlir][Vector] Support 0-d vectors natively in TransferOpReduceRank”).
Operations in Category 2 simply do not support 0-D vectors by definition, and should not have 0-D vector support.
Operations in Category 3 should decide whether to support 0-D vectors based on the intrinsic they are targeting.
Treating a Category 2 operation as Category 1 generally leads to abstraction mismatch and bugs.
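As a rough illustration of the three categories, here is a small Python sketch of which vector ranks each category is well defined on. The names `Category`, `accepts_rank` and `intrinsic_allows_0d` are made up for this post and are not part of MLIR. It only models ranks, nothing else.

```python
from enum import Enum


class Category(Enum):
    ND = 1           # defined on an N-D vector space
    STACK_OF_1D = 2  # defined on 1-D vectors, extended to N-D as a stack of 1-D vectors
    ONLY_1D = 3      # defined on 1-D vectors only, usually mirroring an intrinsic


def accepts_rank(category: Category, rank: int, intrinsic_allows_0d: bool = False) -> bool:
    """Return True if an op in `category` is well defined on a vector of `rank`."""
    if rank < 0:
        raise ValueError("rank must be non-negative")
    if category is Category.ND:
        return True        # every rank, including 0-D
    if category is Category.STACK_OF_1D:
        return rank >= 1   # a stack of 1-D vectors has no 0-D case
    # Category 3: always 1-D; 0-D only if the targeted intrinsic allows it
    # (hypothetical flag, decided per operation).
    return rank == 1 or (rank == 0 and intrinsic_allows_0d)
```

In this model every rank accepted by Category 3 (ignoring the per-intrinsic 0-D choice) is also accepted by Category 2, and every rank accepted by Category 2 is accepted by Category 1, which is the restriction ordering described above.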
I’ll go through some example operations to show that every operation fits into one of these 3 categories, that Category 1 operations need 0-d vectors to be defined properly, and that Category 2 operations cause problems when they try to behave like Category 1 operations.
vector.contract
vector.contract is a classic example of a Category 1 operation, defined on an N-D vector space. From the docs:
Computes the sum of products of vector elements along contracting dimension pairs from 2 vectors of rank M and N respectively, adds this intermediate result to the accumulator argument of rank K, and returns a vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims + num_batch_dims (see dimension type descriptions below)). For K = 0 (no free or batch dimensions), the accumulator and output are a scalar.
The operation has to special-case scalars because it is a Category 1 operation, which needs 0-D vectors to be defined properly. It should support 0-D vectors, which would reduce special casing and bugs.
Example special casing in vector.contract:
The same logic applies to vector.multi_reduction, vector.transfer_read, vector.transfer_write, … These operations work on an N-D vector space and require 0-D vectors to be defined properly.
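To make the vector.contract rank rule concrete, here is a minimal Python sketch of the K = num_lhs_free_dims + num_rhs_free_dims + num_batch_dims formula from the docs. The function name and the `native_0d` flag are hypothetical. The flag only contrasts the documented behaviour (scalar when K = 0) with what uniform 0-D support would look like.

```python
def contract_result_kind(num_lhs_free: int, num_rhs_free: int,
                         num_batch: int, native_0d: bool) -> tuple:
    """Return ("scalar" | "vector", rank) for the accumulator/result.

    native_0d=False models the documented rule: K == 0 degenerates to a scalar.
    native_0d=True models a Category 1 definition: K == 0 is a 0-D vector.
    """
    k = num_lhs_free + num_rhs_free + num_batch
    if k == 0 and not native_0d:
        return ("scalar", 0)   # the special case every consumer has to handle
    return ("vector", k)       # one uniform rule for all K >= 0


# Matrix-matrix style contraction: one free dim on each side, no batch dims.
print(contract_result_kind(1, 1, 0, native_0d=False))
# Dot-product style contraction: no free or batch dims.
print(contract_result_kind(0, 0, 0, native_0d=False))
print(contract_result_kind(0, 0, 0, native_0d=True))
```

With `native_0d=True` the K = 0 case stops being a separate branch, which is the kind of special casing I’m arguing against.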
vector.shuffle
vector.shuffle is a classic example of a Category 2 operation (and the one shown to have problems in your original post). From the docs:
The legality rules are:
- the two operands must have the same element type as the result
- Either, the two operands and the result must have the same rank and trailing
dimension sizes, viz. given two k-D operands v1 : <s_1 x s_2 x .. x s_k x
type> and v2 : <t_1 x t_2 x .. x t_k x type> we have s_i = t_i for all 1 < i
<= k
- Or, the two operands must be 0-D vectors and the result is a 1-D vector.
...
Examples:
%2 = vector.shuffle %a, %b[3, 2, 1, 0]
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
%3 = vector.shuffle %a, %b[0, 1]
: vector<f32>, vector<f32> ; yields vector<2xf32>
The op needs to be special cased for 0-D vectors, because it falls in Category 2 and is being forced to work with 0-D vectors. It should disallow 0-D vectors, which will reduce bugs and special casing for it.
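Here is a Python sketch of the shape rule behind the two quoted examples, to show where the 0-D branch sits. It assumes, as my reading of the op and not something stated in the excerpt above, that the result’s leading dimension is the mask length, that trailing dimensions are kept, and that mask indices range over the concatenated leading dimensions. The function name is made up.

```python
def shuffle_result_shape(v1_shape, v2_shape, mask):
    """Shape-only model of vector.shuffle; shapes are tuples of ints."""
    v1, v2 = tuple(v1_shape), tuple(v2_shape)
    if len(v1) != len(v2):
        raise ValueError("operands must have the same rank")
    if len(v1) == 0:
        # Special case: each 0-D operand is treated like a single element.
        selectable = 2
        trailing = ()
    else:
        if v1[1:] != v2[1:]:
            raise ValueError("trailing dimensions must match")
        selectable = v1[0] + v2[0]   # positions in the concatenated leading dim
        trailing = v1[1:]
    if any(not 0 <= i < selectable for i in mask):
        raise IndexError("mask index out of range")
    return (len(mask),) + trailing
```

For the two doc examples, `shuffle_result_shape((2,), (2,), [3, 2, 1, 0])` gives `(4,)` and `shuffle_result_shape((), (), [0, 1])` gives `(2,)`. The second only works because of the dedicated rank-0 branch. Without it, the "stack of 1-D vectors" rule has no leading dimension to concatenate.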
vector.reduction
From the docs:
Note that these operations are restricted to 1-D vectors to remain close to the corresponding LLVM intrinsics:
LLVM Language Reference Manual — LLVM 20.0.0git documentation
vector.reduction is a classic Category 3 operation. It is meant to target an LLVM intrinsic and should not support 0-D vectors, to match the corresponding LLVM intrinsic.
Other good examples of such operations are vector.matrix_multiply and vector.outerproduct.
vector.extract / vector.insert
These operations are special. They were originally designed to work as Category 2 operations, but with the addition of 0-D vectors they were extended into a mix of Category 1 and Category 2 operations. When they act as Category 2 operations (returning a scalar instead of a 0-D vector by default, for example), Category 1 operations have to special-case the result, and this causes multiple bugs. It is ambiguous which category these operations fall into. My understanding is that this RFC is trying to remove this ambiguity and put them firmly into one category, which is a good thing.
However, unlike vector.shuffle, vector.extract/vector.insert are very core to the vector dialect and act as glue for all operations in the dialect. The current RFC is trying to make these operations strictly Category 2, which is why I’m -1 on this RFC. It will mean more bugs and special casing for us on Category 1 operations. We will just have a different set of bugs and the problem will just get displaced elsewhere (as @nicolasvasilache mentions).
For example, when lowering vector.multi_reduction, the lowering has to special case if it sees a scalar accumulator because it extracted a lower dimensional vector:
The proper solution would be to make vector.extract/vector.insert Category 1 operations, so that they work well with every vector dialect operation. I propose how to do this below.
A better charter for vector dialect
The above definitions make it clear when an operation should support 0-D vectors
(Category 1), when it shouldn’t (Category 2), and when it depends on what it’s targeting (Category 3).
We should start by sorting the operations into the category each belongs to. That would make it much clearer how they need to be defined and would eliminate most of the bugs that we face today. It would also expose operations that are poorly defined and give us a chance to define them better.
Problems with vector.extract / vector.insert
I will take vector.extract as an example. The same argument applies to vector.insert. From the docs for vector.extract:
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at the proper position. Degenerates to an element type if n-k is zero.
The problem with this operation is that it is defined with Category 2 semantics (a stack of 1-d vectors that degenerates to a scalar once we go below 1-d). When it is used with Category 1 operations, every consumer has to special-case the scalar result, and missing that special case leads to the usual bugs (as shown in the examples above).
The problem is fixed if we split the operation into two:
- vector.extract: Extract a trailing (n-k)-D vector
- vector.extract_scalar: Extract a scalar from an N-D vector
This leads to a consistent definition of the semantics of the operation over an
N-D vector space, and makes it a Category 1 operation (which means it can be used with Category 2 and Category 3 operations as well). It also makes it clear that when working with Category 2 operations, you must explicitly use vector.extract_scalar instead of vector.extract to get an element, because 0-D vectors do not make sense for Category 2 operations.
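The difference between the current definition and the proposed split can be sketched at the shape level in Python. All three function names are made up, and the requirement that extract_scalar takes a full-rank position is my assumption about how the split would work, not settled semantics.

```python
def _check(shape, position):
    """Validate that `position` indexes a leading slice of `shape`."""
    if len(position) > len(shape):
        raise ValueError("position has more indices than the vector has dims")
    for idx, dim in zip(position, shape):
        if not 0 <= idx < dim:
            raise IndexError("position out of bounds")


def extract_current(shape, position):
    """Documented rule: (n-k)-D vector, degenerating to a scalar when n-k == 0."""
    _check(shape, position)
    rest = tuple(shape[len(position):])
    return ("scalar", ()) if not rest else ("vector", rest)


def extract_proposed(shape, position):
    """Proposed vector.extract: always an (n-k)-D vector, including 0-D."""
    _check(shape, position)
    return ("vector", tuple(shape[len(position):]))


def extract_scalar_proposed(shape, position):
    """Proposed vector.extract_scalar: full-rank position, always a scalar."""
    _check(shape, position)
    if len(position) != len(shape):
        raise ValueError("extract_scalar needs one index per dimension")
    return ("scalar", ())
```

For a `vector<2x3xf32>` and position `[1, 2]`, `extract_current` returns a scalar, `extract_proposed` returns a 0-D vector, and `extract_scalar_proposed` returns a scalar. The result kind of each proposed operation no longer depends on whether n-k happens to be zero.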
Proposed action points
I’m proposing two things here, which are in spirit of this RFC, but a different solution to the problem:
- Put each operation in one of these three categories, and only allow 0-d vectors
for Category 1.
- Split vector.extract into two operations and make both of them Category 1.
(I wrote this based on discussions with @qed @kuhar @hanchung @MaheshRavishankar @manupak to understand vector dialect operations better and why we face 0-D vector bugs)