Given that there’s growing interest in handling scalable vectorization in MLIR (see: [RFC] Add RISC-V Vector Extension (RVV) Dialect, and the older proposal: [RFC] Vector Dialects: Neon and SVE), I propose adding support for scalable vectors to the built-in vector type.
Motivation
The two main issues with the current approach used by ArmSVE and the proposed RVV are:
- Handling scalable vectors as low-level dialect types creates redundancy between the different hardware-specific dialects, and between those dialects and common ones. E.g., each hardware-specific dialect has to replicate Arith operations and their lowering to LLVM IR just to get them to accept scalable vector operands.
- Adapting vectorization and lowering passes to work with scalable vectors requires either making the common dialects depend on very low-level dialects, or duplicating those passes within each of the hardware-specific dialects.
These issues make working with scalable vectors extremely cumbersome, and they are bound to cause maintainability problems in the scalable-vector dialects, in higher-level dialects, or both. If MLIR is meant to work on all kinds of hardware, we need a way to indicate scalability in vectors across all dialects that work with them.
What is a scalable vector type?
A scalable vector type is a SIMD type that holds a number of elements that is a multiple of a fixed base size. The multiple is unknown at compile time, but it is a constant at runtime. E.g., something like:
%0 = arith.addf %a, %b : vector<4xf32>
indicates the addition of two vectors of 4 single precision floating point elements. On the other hand, if we represent a scalable vector using double angle brackets, something like:
%0 = arith.addf %a, %b : vector<<4xf32>>
indicates the addition of two vectors that contain a multiple of 4 single precision floating point elements. The multiple is a runtime constant, represented by the vector_scale operation, and its value is determined by the hardware implementation. (This mirrors LLVM IR, where the equivalent type is written `<vscale x 4 x float>` and the multiple is exposed through the llvm.vscale intrinsic.)
Proposed solution
As a first step, I suggest adding a flag within the VectorType class to represent scalability. This way, everything else keeps working as it is, and much of what is supported in the hardware-specific dialects for scalable vector types works automatically as part of common dialects like Vector (load/store) and Arith (arithmetic & comparison operations).
As the proposed syntax for scalable vectors I've chosen double angle brackets, mostly because it is friendlier to the MLIR parser than the syntax adopted by LLVM IR, but it also makes sense if you squint: in most cases, you can think of a scalable vector as an array of vectors, or a vector of vectors. That said, I will happily change it if there's a better alternative.
With the proposed change, assuming the input parameters are correctly sized, this is how a vector addition function implemented as a vector-length agnostic (VLA) loop would look like:
func @vector_add(%src_a: memref<?xf32>, %src_b: memref<?xf32>, %dst: memref<?xf32>, %size: index) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %vs = vector_scale : index
  %step = arith.muli %c4, %vs : index
  scf.for %i = %c0 to %size step %step {
    %0 = vector.load %src_a[%i] : memref<?xf32>, vector<<4xf32>>
    %1 = vector.load %src_b[%i] : memref<?xf32>, vector<<4xf32>>
    %2 = arith.addf %0, %1 : vector<<4xf32>>
    vector.store %2, %dst[%i] : memref<?xf32>, vector<<4xf32>>
  }
  return
}
And using the mlir-opt options -convert-vector-to-llvm -convert-scf-to-std -convert-std-to-llvm -reconcile-unrealized-casts, the following LLVM dialect code can be obtained:
llvm.func @vector_add(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr<f32>, %arg6: !llvm.ptr<f32>, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: !llvm.ptr<f32>, %arg11: !llvm.ptr<f32>, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: i64) {
  %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %6 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %12 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %13 = llvm.insertvalue %arg10, %12[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %14 = llvm.insertvalue %arg11, %13[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %15 = llvm.insertvalue %arg12, %14[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %16 = llvm.insertvalue %arg13, %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %17 = llvm.insertvalue %arg14, %16[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %18 = llvm.mlir.constant(0 : index) : i64
  %19 = llvm.mlir.constant(4 : index) : i64
  %20 = "llvm.intr.vscale"() : () -> i64
  %21 = llvm.mul %20, %19 : i64
  llvm.br ^bb1(%18 : i64)
^bb1(%22: i64):  // 2 preds: ^bb0, ^bb2
  %23 = llvm.icmp "slt" %22, %arg15 : i64
  llvm.cond_br %23, ^bb2, ^bb3
^bb2:  // pred: ^bb1
  %24 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %25 = llvm.getelementptr %24[%22] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
  %26 = llvm.bitcast %25 : !llvm.ptr<f32> to !llvm.ptr<vector<<4xf32>>>
  %27 = llvm.load %26 {alignment = 4 : i64} : !llvm.ptr<vector<<4xf32>>>
  %28 = llvm.extractvalue %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %29 = llvm.getelementptr %28[%22] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
  %30 = llvm.bitcast %29 : !llvm.ptr<f32> to !llvm.ptr<vector<<4xf32>>>
  %31 = llvm.load %30 {alignment = 4 : i64} : !llvm.ptr<vector<<4xf32>>>
  %32 = llvm.fadd %27, %31 : vector<<4xf32>>
  %33 = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
  %34 = llvm.getelementptr %33[%22] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
  %35 = llvm.bitcast %34 : !llvm.ptr<f32> to !llvm.ptr<vector<<4xf32>>>
  llvm.store %32, %35 {alignment = 4 : i64} : !llvm.ptr<vector<<4xf32>>>
  %36 = llvm.add %22, %21 : i64
  llvm.br ^bb1(%36 : i64)
^bb3:  // pred: ^bb1
  llvm.return
}
This, in turn, can be translated to LLVM IR (e.g., with mlir-translate -mlir-to-llvmir) that compiles to valid code for scalable vector architectures.
I’ve uploaded a patch with the proposed changes to provide some ground for the discussion:
[mlir][RFC] Make scalable vector type a built-in type
Open Issues
Better implementation of the scalable vector type.
In LLVM IR, the type is handled with a sensible class hierarchy, and I believe a similar implementation would also be preferable for MLIR. I've started with the flag-based approach to keep the size and reach of the patch contained: I expect this to be a longer discussion, and chasing VectorType uses throughout all the dialects over an extended period could be a significant amount of work. That said, once there's agreement on how to do this, I'm happy to skip this patch altogether and go straight for an alternative, if that is preferred.
Thank you in advance for your feedback on this topic.