I’m trying to work out how to interface between fixed-length and scalable contexts within MLIR. I believe there are a few of us working on this, and now is probably the best time to tackle it. First, let me start with a description of the problem and its motivation.
Quick refresh on scalable vectors
(Skip if you’re already familiar with scalable vectors)
A scalable vector holds a number of elements that’s a multiple of a base size, and that multiple is a runtime constant: the vector scale (a VPU design parameter).
For instance, if a vector<4xf32> can hold 4 floating point elements, a scalable vector<4xf32> can hold 4, 8, 12, … up to a limit defined by the ISA. In MLIR, we represent such a vector as vector<[4]xf32>, meaning that the dimension of 4 elements can have a multiplicity, or vector scale, greater than 1.
Scalable vectors introduce a new software-side concept: vector-length agnosticism. An operation is vector-length agnostic when it works for any possible vector scale, and likewise, code is vector-length agnostic when it works for any possible vector scale.
Conceptually, if a fixed-length vector addition loop is something like:
for (unsigned i = 0; i < num_data_elements; i += 4) {
  v4f32 a = load_vector(data, i);   // Load <4 x f32>
  v4f32 b = a + a;                  // Element-wise add of <4 x f32> to <4 x f32>
  store_vector(result, i, b);       // Store <4 x f32>
}
A vector-length agnostic equivalent is something like:
for (unsigned i = 0; i < num_data_elements; i += 4 * vector_scale) {
  sv4f32 a = load_scalable_vector(data, i);   // Load <vscale x 4 x f32>
  sv4f32 b = a + a;                           // Element-wise add of <vscale x 4 x f32> vectors
  store_scalable_vector(result, i, b);        // Store <vscale x 4 x f32>
}
Where vector_scale is the runtime constant defining the scale of our vectors. When the operations are more complex than a simple element-wise vector addition (think of horizontal reductions), a useful conceptual model is to understand the vector-length agnostic operation as an implicit loop over contiguous vectors of the base size.
For instance, the following is a common vector-length agnostic way to perform a 4-wise dot product:
for (unsigned i = 0; i < num_data_elements; i += 4 * vector_scale) {
  sv4f32 a = load_scalable_vector(a_data, i);   // Load <vscale x 4 x f32>
  sv4f32 b = load_scalable_vector(b_data, i);   // Load <vscale x 4 x f32>
  sv4f32 c = vla_scalable_dot_product(a, b);    // Perform vscale x <4 x f32> by <4 x f32> dot products
  store_scalable_vector(c_data, i, c);          // Store <vscale x 4 x f32>
}
The reduction doesn’t happen across the whole length of the physical vector, but across segments of the base vector length (4 x f32), as if the code were:
for (unsigned i = 0; i < num_data_elements; i += 4 * vector_scale) {
  for (unsigned j = i; j < i + 4 * vector_scale; j += 4) {
    v4f32 a = load_vector(a_data, j);   // Load <4 x f32>
    v4f32 b = load_vector(b_data, j);   // Load <4 x f32>
    v4f32 c;
    c[0] = dot_product(a, b);           // Perform one <4 x f32> by <4 x f32> dot product
    store_vector(c_data, j, c);         // Store <4 x f32>
  }
}
The opposite of a vector-length agnostic (VLA) operation or code is a vector-length specific (VLS) operation or code. Notice that not all operations on scalable vectors are vector-length agnostic. For instance, shuffle or extract ops that operate over the total length of the vector (not a base segment) would not be VLA, even if they are performed on scalable vectors. Likewise, a non-segmented horizontal reduction on a scalable vector (e.g.: an operation that computes the addition of all the values in a scalable vector and returns a single scalar) is not VLA either.
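To make the distinction concrete, here is a small LLVM IR sketch of my own (not part of the proposal), contrasting the two kinds of operations using the upstream intrinsic names for SVE’s segmented dot product and the generic whole-vector reduction:

; VLA in the sense above: each 32-bit lane of the accumulator receives the dot
; product of the corresponding 4-element i8 segment, whatever vscale is.
declare <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(
    <vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)

; Not VLA: folds every element of the physical register into a single scalar,
; so the set of values it combines depends on the runtime vscale.
declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)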
Vector-length specific code on scalable architectures
Why?
For performance reasons, even if our vector architecture is scalable, we may want to generate VLS code. We can assume a specific vector scale, and generate code with a known target vector length. In principle, as long as we make sure that the host architecture that will run the code supports the target vector length, we can generate correct VLS code that runs on a scalable architecture.
There are a couple of different ways to go about this:
- Generate fixed-length vectors of the appropriate size (the assumed vscale times the base length), for instance vector<16xf32> for a scalable architecture of 512 bits and a base length of 128 bits (a vscale of 4), and use function attributes that force the instruction selector to pick scalable instructions (if there is a fixed-length alternative); see the sketch after this list.
- Generate scalable vectors (e.g.: vector<[4]xf32> in the example above), but assume a fixed size in your loop steps.
Option two comes without any interfacing issues, but forces you to use generic VLA vector instructions in the IR and prevents VLS code generation strategies. From a performance point of view, the first option is the most interesting one, but it comes with some interfacing issues.
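For concreteness, a minimal LLVM IR sketch of option one could look like the following (the function @add_vls and its body are hypothetical; vscale_range is the existing LLVM function attribute):

; Fixed-length IR plus a vscale_range attribute pinning the runtime vector
; scale to 4, i.e. a 512-bit implementation with a 128-bit base vector length.
define <16 x float> @add_vls(<16 x float> %a, <16 x float> %b) #0 {
  ; Fixed-length op; with the attribute below, the backend may still select
  ; scalable instructions for it.
  %sum = fadd <16 x float> %a, %b
  ret <16 x float> %sum
}

attributes #0 = { vscale_range(4,4) }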
The problem
For basic arithmetic instructions, using scalable instructions or fixed-length instructions is entirely up to the instruction selector, and function attributes like vscale_range can force the selection of scalable instructions.
For complex, architecture-specific operations (dot products, inner product, outer products, …) we need to generate intrinsics that are only defined for scalable operands. If our code has been generated with fixed-length vectors but we want to rewrite a higher level vector instruction with a hw-specific scalable vector intrinsic, we need a way to cast the incoming fixed-length operands into equivalent scalable vectors, and back from scalable to fixed-length vectors for the result.
For instance:
#gemm_trait = {
indexing_maps = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
],
iterator_types = ["parallel", "parallel", "reduction"]
}
func.func @gemm(%a: vector<2x8xi8>, %b: vector<8x2xi8>, %acc: vector<2x2xi32>)
    -> vector<2x2xi32> {
  %0 = vector.contract #gemm_trait %a, %b, %acc : vector<2x8xi8>, vector<8x2xi8> into vector<2x2xi32>
  return %0 : vector<2x2xi32>
}
Since the semantics match, the vector contraction can be rewritten as the arm_sve.smmla instruction from the ArmSVE dialect, but while the fixed-length vector.contract takes fixed-length vectors, arm_sve.smmla takes scalable vectors. We need a mechanism to perform that cast.
Solutions in LLVM IR
In LLVM IR, there are a couple of experimental intrinsics, llvm.experimental.vector.insert and llvm.experimental.vector.extract, that allow the insertion/extraction of fixed-length vectors into/from scalable vectors. This way, you can pack fixed-length vectors into scalable vectors, call a scalable vector function or intrinsic, and unpack the result back to fixed-length vectors. Like so:
define void @fl2svmuladd(float* %arg0, float* %arg1, float* %arg2) {
  ; Fixed-length world
  %0 = bitcast float* %arg0 to <8 x float>*
  %1 = bitcast float* %arg1 to <8 x float>*
  %2 = bitcast float* %arg2 to <8 x float>*
  %3 = load <8 x float>, <8 x float>* %0
  %4 = load <8 x float>, <8 x float>* %1
  %5 = load <8 x float>, <8 x float>* %2
  ; Fixed-length to scalable
  %6 = call <vscale x 4 x float> @llvm.experimental.vector.insert.nxv4f32.v8f32(<vscale x 4 x float> undef, <8 x float> %3, i64 0)
  %7 = call <vscale x 4 x float> @llvm.experimental.vector.insert.nxv4f32.v8f32(<vscale x 4 x float> undef, <8 x float> %4, i64 0)
  %8 = call <vscale x 4 x float> @llvm.experimental.vector.insert.nxv4f32.v8f32(<vscale x 4 x float> undef, <8 x float> %5, i64 0)
  ; Scalable world
  %9 = call <vscale x 4 x float> @llvm.fmuladd.nxv4f32(<vscale x 4 x float> %6, <vscale x 4 x float> %7, <vscale x 4 x float> %8)
  ; Scalable to fixed-length
  %10 = call <8 x float> @llvm.experimental.vector.extract.v8f32.nxv4f32(<vscale x 4 x float> %9, i64 0)
  ; Back in fixed-length world
  store <8 x float> %10, <8 x float>* %2
  ret void
}
Interfacing between fixed-length and scalable vectors in MLIR
The question I’d like to ask is, what’s the best way to address this issue of mixed fixed-length and scalable vectors within MLIR?
My assumptions:
- We want to do this only when we’re interfacing fixed-length vector code in the Vector dialect with one of the intrinsics in the scalable hw-specific dialects (ArmSVE or RVV for now).
  - This will be a common occurrence when we use fixed-length vectorization strategies but want to target a complex non-SIMD vector operation (like gemm acceleration ops) on a scalable architecture.
- For these complex vector operations, the fixed-length to scalable conversion will often be accompanied by a shape conversion.
  - Since LLVM IR only admits rank-1 vectors but these operations often have multi-rank semantics (like sdot, which operates on base segments, or smmla, which operates on tiled data), the switch from fixed-length to scalable will be preceded by a flattening of the vector, and the switch from scalable to fixed-length will be followed by a reshape from a linear to a multi-rank vector.
Based on these assumptions, although the obvious answer to this question might be to extend vector.insert and vector.extract to accept mixed scalability, and lower those cases to the LLVM IR intrinsics, I believe the operation that makes the most sense for this process is vector.shape_cast.
From a high-level point of view, even if we implement it with these insert/extract constructs, going from fixed-length to scalable is more of a “shape cast” type of operation. Since I anticipate the need for a shape cast anyway, I think that is the operation we should modify for this.
If we go back to the arm_sve.smmla example above, the resulting conversion would be:
func.func @gemm(%a: vector<2x8xi8>, %b: vector<8x2xi8>, %acc: vector<2x2xi32>)
    -> vector<2x2xi32> {
  %sa = vector.shape_cast %a : vector<2x8xi8> to vector<[16]xi8>
  %sb = vector.shape_cast %b : vector<8x2xi8> to vector<[16]xi8>
  %sc = vector.shape_cast %acc : vector<2x2xi32> to vector<[4]xi32>
  %0 = arm_sve.smmla %sc, %sa, %sb : vector<[16]xi8> to vector<[4]xi32>
  %res = vector.shape_cast %0 : vector<[4]xi32> to vector<2x2xi32>
  return %res : vector<2x2xi32>
}
Lowering shape cast between scalable and fixed-length vectors
The follow-up question is, how do we lower these mixed scalability shape casts?
The process of casting a scalable vector to a fixed-length vector consists in defining the vscale constant at compile time. That is, going from something like vector<[4]xf32> to vector<8xf32>, where we are forcing our scalable architecture to be a 256b vector architecture (vscale = 2). Likewise, if we have a vector<8xf32>, a 256b vector, and we want to map it to a scalable architecture with 128b of base vector length, we can trivially do so by dividing the length by the vscale: vector<[4]xf32>.
My proposal is that we can lower these “trivial” shape casts, in which the fixed-length size is a multiple of the base size of the scalable vector, to experimental.vector.insert/extract in the conversion from the Vector dialect to the LLVM dialect. E.g.:
%sv = vector.shape_cast %in : vector<8xf32> to vector<[4]xf32>
%flv = vector.shape_cast %sv : vector<[4]xf32> to vector<8xf32>
Can be trivially lowered to:
%loc = arith.constant dense<0.0> : vector<[4]xf32>
%slv = llvm.intr.vector.insert %in, %loc[0] : vector<8xf32> into vector<[4]xf32>
%flv = llvm.intr.vector.extract %slv[0] : vector<8xf32> from vector<[4]xf32>
For the non-trivial case, I propose to decompose the lowering into the flattening shape cast + a trivial fixed-length/scalable shape change (and vice versa).
For instance, going back to the arm_sve.smmla example, the first lowering step would be:
func.func @gemm(%a: vector<2x8xi8>, %b: vector<8x2xi8>, %acc: vector<2x2xi32>)
    -> vector<2x2xi32> {
  %0 = vector.shape_cast %a : vector<2x8xi8> to vector<16xi8>
  %sa = vector.shape_cast %0 : vector<16xi8> to vector<[16]xi8>
  %1 = vector.shape_cast %b : vector<8x2xi8> to vector<16xi8>
  %sb = vector.shape_cast %1 : vector<16xi8> to vector<[16]xi8>
  %2 = vector.shape_cast %acc : vector<2x2xi32> to vector<4xi32>
  %sc = vector.shape_cast %2 : vector<4xi32> to vector<[4]xi32>
  %3 = arm_sve.smmla %sc, %sa, %sb : vector<[16]xi8> to vector<[4]xi32>
  %4 = vector.shape_cast %3 : vector<[4]xi32> to vector<4xi32>
  %5 = vector.shape_cast %4 : vector<4xi32> to vector<2x2xi32>
  return %5 : vector<2x2xi32>
}
From there, the fixed-length to fixed-length vector.shape_cast operations are lowered as usual, and the trivial fixed-length to scalable casts (and vice versa) are lowered to llvm.intr.vector.insert/extract.
For a slightly more complex example, if we take one of the operands of arm_sve.smmla for the vscale = 4 case, that is, a 512b scalable architecture with a base vector size of 128b, the lowering process of one of the operands would be:
Initial:
%sv = vector.shape_cast %a : vector<4x2x8xi8> to vector<[16]xi8>
First lowering step (vector.shape_cast → vector.shape_cast):
%tv = vector.shape_cast %a : vector<4x2x8xi8> to vector<64xi8>
%sv = vector.shape_cast %tv : vector<64xi8> to vector<[16]xi8>
Second lowering step (vector.shape_cast → vector.insert/extract):
%cst = arith.constant dense<0> : vector<64xi8>
%0 = vector.extract %a[0, 0, 0] : vector<4x2x8xi8>
%1 = vector.insert %0, %cst [0] : i8 into vector<64xi8>
%2 = vector.extract %a[0, 0, 1] : vector<4x2x8xi8>
%3 = vector.insert %2, %1 [1] : i8 into vector<64xi8>
%4 = vector.extract %a[0, 0, 2] : vector<4x2x8xi8>
[...]
%126 = vector.extract %a[3, 1, 7] : vector<4x2x8xi8>
%127 = vector.insert %126, %125 [63] : i8 into vector<64xi8>
%sv = vector.shape_cast %127 : vector<64xi8> to vector<[16]xi8>
Last lowering step + canonicalization (vector.insert/extract & vector.shape_cast → LLVM):
%0 = llvm.mlir.constant(63 : i64) : i64
%1 = llvm.mlir.constant(62 : i64) : i64
%2 = llvm.mlir.constant(61 : i64) : i64
[...]
%63 = llvm.mlir.constant(0 : i64) : i64
%cst = arith.constant dense<0> : vector<64xi8>
%64 = builtin.unrealized_conversion_cast %arg0 : vector<4x2x8xi8> to !llvm.array<4 x array<2 x vector<8xi8>>>
%65 = llvm.extractvalue %64[0, 0] : !llvm.array<4 x array<2 x vector<8xi8>>>
%66 = llvm.extractelement %65[%63 : i64] : vector<8xi8>
%67 = llvm.insertelement %66, %cst[%63 : i64] : vector<64xi8>
[...]
%253 = llvm.insertelement %252, %250[%1 : i64] : vector<64xi8>
%254 = llvm.extractvalue %64[3, 1] : !llvm.array<4 x array<2 x vector<8xi8>>>
%255 = llvm.extractelement %254[%56 : i64] : vector<8xi8>
%256 = llvm.insertelement %255, %253[%0 : i64] : vector<64xi8>
%tmp = arith.constant dense<0> : vector<[16]xi8>
%sv = llvm.intr.experimental.vector.insert %256, %tmp[0] : vector<64xi8> into vector<[16]xi8>
I’ve already submitted a patch that adds the experimental.vector.insert/extract intrinsics to the LLVM Dialect (D127100), and I’m working on another patch that will extend vector.shape_cast in the way I’ve described above.
I’ve already discussed this with several people, publicly and in private, and I’d like to hear opinions from the community.