Authors: @nicolasvasilache (Google), @javiersetoain (Arm)
Hello everyone,
We have recently been discussing with Arm, and interest has grown in supporting the concrete mixed-precision and scalable operations provided by Neon and SVE. We have built some prototypes to connect the dots and propose to upstream these dialects to make compilation for mobile targets a reality in MLIR. Since these dialects are being developed by, and are of interest to, the same people, have similar HW targets, and sit at similar levels of abstraction, they are proposed for inclusion in a joint RFC.
What is the overall goal of the dialect?
The Vector Dialect document discusses the MLIR vector abstractions and their tradeoffs. The Hardware Vector Ops (HWV) level is provisioned to allow representing non-portable operations and have them interoperate with portable vector operations and MLIR codegen. This proposal is for adding new `Targets/Neon` and `Targets/SVE` dialects that would directly model target-specific intrinsics.
This proposal allows 3 concrete things:
- make it possible to represent mixed-precision and scalable workloads in MLIR all the way to execution on specific HW.
- connect MLIR codegen to these specific abstractions and explore transformations targeting mixed-precision and scalable workloads.
- further explore tradeoffs and interop. between HW-specific and HW-agnostic vector abstractions in MLIR.
What is the first implementation milestone?
The first implementation milestone adds the LLVM-level dialects and implements some basic operations (say `*mull`, `*dot`, `*mmla`). Like other intrinsics in the LLVM dialect, they are lightweight and are represented in their custom op form (i.e. no special parsing / printing behavior).
The Tablegen specification resembles:
def LLVM_aarch64_neon_smull :
LLVMNeon_IntrBinaryOverloadedOp<"smull">, Arguments<(ins LLVM_Type, LLVM_Type)>;
and
def LLVM_aarch64_sve_smmla :
LLVMSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_sve_sdot :
LLVMSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
The LLVM dialect form resembles:
llvm.func @neon_smull(%arg0: !llvm.vec<8 x i8>, %arg1: !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16> {
%0 = "llvm_neon.smull"(%arg0, %arg1) : (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16>
llvm.return %0 : !llvm.vec<8 x i16>
}
and
llvm.func @sve_sdot(%arg0: !llvm.vec<? x 16 x i8>, %arg1: !llvm.vec<? x 16 x i8>, %arg2: !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> {
%0 = "llvm_sve.sdot"(%arg2, %arg0, %arg1) : (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
llvm.func @sve_smmla(%arg0: !llvm.vec<? x 16 x i8>, %arg1: !llvm.vec<? x 16 x i8>, %arg2: !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> {
%0 = "llvm_sve.smmla"(%arg2, %arg0, %arg1) : (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
These intrinsics have counterpart operations specified on MLIR 1-D vector types for the purpose of type checking, interop with portable vector ops, as well as codegen and progressive lowering. The Tablegen specification is:
def SMullOp : Neon_Op<"smull", [NoSideEffect,
AllTypesMatch<["a", "b"]>,
TypesMatchWith<
"res has same vector shape and element bitwidth scaled by 2 as a",
"a", "res", "$_self.cast<VectorType>().scaleElementBitwidth(2)">]> {
let summary = "smull op";
let description = [{
/* Doc to extract from Neon ISA manual */
}];
// Supports either:
// (vector<8xi8>, vector<8xi8>) -> (vector<8xi16>)
// (vector<4xi16>, vector<4xi16>) -> (vector<4xi32>)
// (vector<2xi32>, vector<2xi32>) -> (vector<2xi64>)
let arguments = (ins VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$a,
VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$b);
let results = (outs VectorOfLengthAndType<[8, 4, 2], [I16, I32, I64]>:$res);
let assemblyFormat =
"$a `,` $b attr-dict `:` type($a) `to` type($res)";
}
def SmmlaOp : SVE_Op<"smmla",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
The smmla op is an SVE-specific op that lowers to the corresponding LLVM SVE intrinsic: `llvm.aarch64.sve.smmla`.
/* Doc to extract from SVE ISA manual */
}];
// Supports scalable vectors of 16 x i8 inputs and a scalable 4 x i32 accumulator.
let arguments = (ins
ScalableVectorOf<[I32]>:$acc,
ScalableVectorOf<[I8]>:$src1,
ScalableVectorOf<[I8]>:$src2
);
let results = (outs ScalableVectorOf<[I32]>:$dst);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
}
The MLIR form for Neon ops composes with existing retargetable vector ops and resembles:
func @neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
-> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
%0 = neon.smull %a, %b : vector<8xi8> to vector<8xi16>
%00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
vector<8xi16> to vector<4xi16>
%1 = neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
%11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
vector<4xi32> to vector<2xi32>
%2 = neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
}
The SVE dialect, on the other hand, defines its own scalable vector type and does not yet compose with existing retargetable vector ops; it resembles:
func @sve_smmla(%a: !sve.vector<16xi8>, %b: !sve.vector<16xi8>, %c: !sve.vector<4xi32>) -> !sve.vector<4xi32>
{
%0 = sve.smmla %c, %a, %b : !sve.vector<16xi8> -> !sve.vector<4xi32>
return %0 : !sve.vector<4xi32>
}
This is a good starting point to support mixed target-agnostic and target-specific lowering.
Scalable Vector Type Representation
Scalable vectors mirror the syntax of standard vectors. If `vector<4xf32>` is a fixed-length vector containing 4 single-precision floating point elements, `!sve.vector<4xf32>` is a scalable vector containing a HW-dependent symbolic multiple of 4 single-precision floating point elements.
A scalable vector like the one above converts to a scalable vector in the LLVM Dialect:
!llvm.vec<? x 4 x f32>
which in turn translates into LLVM IR:
<vscale x 4 x float>
Connection to control-flow (e.g. loops) and memory operations follows VLA-style programming, which is akin to parametric tiling. The SVE dialect introduces an intrinsic to represent the scale and lowers it to the appropriate LLVM instruction:
sve.vscale : index
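For illustration, here is a minimal sketch of the VLA style alluded to above; the function name and loop body are hypothetical and only meant for exposition, not part of the prototype:
// Hypothetical sketch: a loop whose step is a runtime multiple of the HW vector length.
func @vla_style_loop(%n: index) {
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  // One scalable register holds vscale * 4 lanes of i32.
  %vscale = sve.vscale : index
  %step = muli %vscale, %c4 : index
  scf.for %i = %c0 to %n step %step {
    // ... scalable-vector loads, compute (e.g. sve.sdot) and stores ...
  }
  return
}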
The best way to lower from Vector Dialect to Scalable Vector-based code remains an open question. We expect that the availability of these lower-level vector dialects will help experimentation to determine the best way to lower from high-level fixed-length vector code down to low-level scalable vector code.
In this first incarnation, the scalable vector type is confined to the SVE dialect, which encompasses Arm-specific instructions. Once we gain more experience with end-to-end codegen of scalable vectors, we expect to separate out the type representation so that the infrastructure becomes generally reusable across HW.
How does it fit into the MLIR dialect ecosystem?
Connection: how does it connect to the existing dialects in a compilation pipeline(s)?
The Neon and SVE dialects sit at the Hardware Vector Ops (HWV) layer of the codegen diagram in the Vector Dialect document.
The compilation pipeline will start by allowing naive codegen of higher-level ops (e.g. Linalg, loops) that carry the payload information of these new ops.
More elaborate compilation, notably involving scalar-to-vector conversions in the presence of these new ops and scalable vector types, is the subject of ongoing investigations.
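As a concrete, hypothetical illustration of the payload in question, consider a scalar i8 x i8 -> i32 dot-product loop; recognizing this mixed-precision pattern is what would eventually allow mapping it onto ops such as neon.smull or sve.sdot instead of widening every operand to i32 up front. The function below is a sketch for exposition only, not part of the prototype:
// Hypothetical example: mixed-precision payload amenable to neon.smull / sve.sdot.
func @dot_i8_i32(%A: memref<?xi8>, %B: memref<?xi8>, %n: index) -> i32 {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %zero = constant 0 : i32
  %res = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %zero) -> (i32) {
    %a = load %A[%i] : memref<?xi8>
    %b = load %B[%i] : memref<?xi8>
    // The i8 * i8 -> i32 structure is the information the target-specific ops exploit.
    %a32 = sexti %a : i8 to i32
    %b32 = sexti %b : i8 to i32
    %p = muli %a32, %b32 : i32
    %acc_next = addi %acc, %p : i32
    scf.yield %acc_next : i32
  }
  return %res : i32
}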
Consolidation: is there already a dialect with a similar goal or matching abstractions; if so, can it be improved instead of adding a new one?
There is no current support for ops with mixed-precision and ops on scalable vectors in MLIR core. Additionally, the design of MLIR vector abstractions provisions for target-specific dialects to capture HW-specific variations.
Reuse: how does it generalize to similar but slightly different use-cases?
In the future, we expect the design of the vector dialect and transformations to be influenced by the Neon and SVE dialects and evolve towards more generality than what is allowed today.
Still, even with future evolutions of the vector dialect, target-specific abstractions that allow finer-grained control than what can be achieved with action-at-a-distance through compiler flags will be useful. High-performance libraries are expected to be simpler to design, and to interoperate better, when the proper abstractions are available at the right level in the IR.
Who are the future contributors/maintainers beyond those who propose the dialect?
It is expected that these dialects will be a generally useful abstraction layer to the MLIR community. While it is the goal that the community itself will contribute to extending and maintaining the abstractions, for the foreseeable future we expect Google and Arm to contribute and maintain these dialects.
Current Status
Prototypes for the Neon dialect and the SVE dialect are available and proposed for upstreaming.
They already allow the generation of the expected assembly.
Neon
From the MLIR input:
func @neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
-> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
%0 = neon.smull %a, %b : vector<8xi8> to vector<8xi16>
%00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
vector<8xi16> to vector<4xi16>
%1 = neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
%11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
vector<4xi32> to vector<2xi32>
%2 = neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
}
Convert to the LLVM Dialect with `mlir-opt -convert-neon-to-llvm`:
module {
llvm.func @neon_smull(%arg0: !llvm.vec<8 x i8>, %arg1: !llvm.vec<8 x i8>) -> !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> {
%0 = "llvm_neon.smull"(%arg0, %arg1) : (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16>
%1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : !llvm.vec<8 x i16>, !llvm.vec<8 x i16>
%2 = "llvm_neon.smull"(%1, %1) : (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32>
%3 = llvm.shufflevector %2, %2 [1, 2] : !llvm.vec<4 x i32>, !llvm.vec<4 x i32>
%4 = "llvm_neon.smull"(%3, %3) : (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64>
%5 = llvm.mlir.undef : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
%6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
%7 = llvm.insertvalue %2, %6[1] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
%8 = llvm.insertvalue %4, %7[2] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
llvm.return %8 : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
}
}
Translate to LLVM IR with `mlir-translate -neon-mlir-to-llvmir`:
define { <8 x i16>, <4 x i32>, <2 x i64> } @neon_smull(<8 x i8> %0, <8 x i8> %1) !dbg !3 {
%3 = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %0, <8 x i8> %1), !dbg !7
%4 = shufflevector <8 x i16> %3, <8 x i16> %3, <4 x i32> <i32 3, i32 4, i32 5, i32 6>, !dbg !9
%5 = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %4, <4 x i16> %4), !dbg !10
%6 = shufflevector <4 x i32> %5, <4 x i32> %5, <2 x i32> <i32 1, i32 2>, !dbg !11
%7 = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %6, <2 x i32> %6), !dbg !12
%8 = insertvalue { <8 x i16>, <4 x i32>, <2 x i64> } undef, <8 x i16> %3, 0, !dbg !13
%9 = insertvalue { <8 x i16>, <4 x i32>, <2 x i64> } %8, <4 x i32> %5, 1, !dbg !14
%10 = insertvalue { <8 x i16>, <4 x i32>, <2 x i64> } %9, <2 x i64> %7, 2, !dbg !15
ret { <8 x i16>, <4 x i32>, <2 x i64> } %10, !dbg !16
}
; Function Attrs: nounwind readnone
declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>) #0
; Function Attrs: nounwind readnone
declare <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16>, <4 x i16>) #0
; Function Attrs: nounwind readnone
declare <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32>, <2 x i32>) #0
Compile to AArch64 assembly with `llc -O3 -mtriple=aarch64-none-linux-gnu -mattr=+neon`:
neon_smull: // @neon_smull
.Lfunc_begin0:
.file 1 "/usr/local/google/home/ntv/github/llvm-project/build/<stdin>"
.loc 1 2 0 // <stdin>:2:0
.cfi_startproc
// %bb.0:
.loc 1 3 10 prologue_end // <stdin>:3:10
smull v0.8h, v0.8b, v1.8b
.loc 1 4 10 // <stdin>:4:10
ext v1.16b, v0.16b, v0.16b, #6
.loc 1 5 10 // <stdin>:5:10
smull v1.4s, v1.4h, v1.4h
.loc 1 6 10 // <stdin>:6:10
ext v2.16b, v1.16b, v1.16b, #4
.loc 1 7 10 // <stdin>:7:10
smull v2.2d, v2.2s, v2.2s
.loc 1 12 5 // <stdin>:12:5
ret
SVE
From the MLIR input:
func @sve_sdot(%a: !sve.vector<16xi8>, %b: !sve.vector<16xi8>, %c: !sve.vector<4xi32>) -> !sve.vector<4xi32>
{
%0 = sve.sdot %c, %a, %b : !sve.vector<16xi8> -> !sve.vector<4xi32>
return %0 : !sve.vector<4xi32>
}
func @sve_smmla(%a: !sve.vector<16xi8>, %b: !sve.vector<16xi8>, %c: !sve.vector<4xi32>) -> !sve.vector<4xi32>
{
%0 = sve.smmla %c, %a, %b : !sve.vector<16xi8> -> !sve.vector<4xi32>
return %0 : !sve.vector<4xi32>
}
func @sve_udot(%a: !sve.vector<16xui8>, %b: !sve.vector<16xui8>, %c: !sve.vector<4xui32>) -> !sve.vector<4xui32>
{
%0 = sve.udot %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32>
return %0 : !sve.vector<4xui32>
}
func @sve_ummla(%a: !sve.vector<16xui8>, %b: !sve.vector<16xui8>, %c: !sve.vector<4xui32>) -> !sve.vector<4xui32>
{
%0 = sve.ummla %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32>
return %0 : !sve.vector<4xui32>
}
func @get_vscale() -> index
{
%0 = sve.vscale : index
return %0 : index
}
Convert to the LLVM Dialect with `mlir-opt -convert-sve-to-llvm`:
module {
llvm.func @sve_sdot(%arg0: !llvm.vec<? x 16 x i8>, %arg1: !llvm.vec<? x 16 x i8>, %arg2: !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> {
%0 = "llvm_sve.sdot"(%arg2, %arg0, %arg1) : (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
llvm.func @sve_smmla(%arg0: !llvm.vec<? x 16 x i8>, %arg1: !llvm.vec<? x 16 x i8>, %arg2: !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> {
%0 = "llvm_sve.smmla"(%arg2, %arg0, %arg1) : (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
llvm.func @sve_udot(%arg0: !llvm.vec<? x 16 x i8>, %arg1: !llvm.vec<? x 16 x i8>, %arg2: !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> {
%0 = "llvm_sve.udot"(%arg2, %arg0, %arg1) : (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
llvm.func @sve_ummla(%arg0: !llvm.vec<? x 16 x i8>, %arg1: !llvm.vec<? x 16 x i8>, %arg2: !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32> {
%0 = "llvm_sve.ummla"(%arg2, %arg0, %arg1) : (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) -> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
llvm.func @get_vscale() -> !llvm.i64 {
%0 = "llvm_sve.vscale"() : () -> !llvm.i64
llvm.return %0 : !llvm.i64
}
}
Translate to LLVM IR with `mlir-translate -sve-mlir-to-llvmir`:
define <vscale x 4 x i32> @sve_sdot(<vscale x 16 x i8> %0, <vscale x 16 x i8> %1, <vscale x 4 x i32> %2) !dbg !3 {
%4 = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> %2, <vscale x 16 x i8> %0, <vscale x 16 x i8> %1), !dbg !7
ret <vscale x 4 x i32> %4, !dbg !9
}
define <vscale x 4 x i32> @sve_smmla(<vscale x 16 x i8> %0, <vscale x 16 x i8> %1, <vscale x 4 x i32> %2) !dbg !10 {
%4 = call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32> %2, <vscale x 16 x i8> %0, <vscale x 16 x i8> %1), !dbg !11
ret <vscale x 4 x i32> %4, !dbg !13
}
define <vscale x 4 x i32> @sve_udot(<vscale x 16 x i8> %0, <vscale x 16 x i8> %1, <vscale x 4 x i32> %2) !dbg !14 {
%4 = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> %2, <vscale x 16 x i8> %0, <vscale x 16 x i8> %1), !dbg !15
ret <vscale x 4 x i32> %4, !dbg !17
}
define <vscale x 4 x i32> @sve_ummla(<vscale x 16 x i8> %0, <vscale x 16 x i8> %1, <vscale x 4 x i32> %2) !dbg !18 {
%4 = call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32> %2, <vscale x 16 x i8> %0, <vscale x 16 x i8> %1), !dbg !19
ret <vscale x 4 x i32> %4, !dbg !21
}
define i64 @get_vscale() !dbg !22 {
%1 = call i64 @llvm.vscale.i64(), !dbg !23
ret i64 %1, !dbg !25
}
declare <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>) #0
declare <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>) #0
declare <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>) #0
declare <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>) #0
declare i64 @llvm.vscale.i64() #1
attributes #0 = { nounwind readnone }
attributes #1 = { nofree nosync nounwind readnone willreturn }
Compile to AArch64 assembly:
// [...]
sve_sdot: // @sve_sdot
sdot z2.s, z0.b, z1.b
mov z0.d, z2.d
ret
// [...]
sve_smmla: // @sve_smmla
smmla z2.s, z0.b, z1.b
mov z0.d, z2.d
ret
// [...]
sve_udot: // @sve_udot
udot z2.s, z0.b, z1.b
mov z0.d, z2.d
ret
// [...]
sve_ummla: // @sve_ummla
ummla z2.s, z0.b, z1.b
mov z0.d, z2.d
ret
// [...]
get_vscale: // @get_vscale
rdvl x8, #1
lsr x0, x8, #4
ret
// [...]
One way to compile the generated LLVM IR to AArch64 could be:
llc -march=aarch64 -mattr=v8.6a,sve
In this case, the minimum requirements for the xMMLA instructions are the `v8.6a` and `sve` attribute flags.