Here I have the initial implementation of an ArmSME dialect: D139875 [RFC][MLIR][ArmSME] Initial implementation of ArmSME Dialect. Any feedback would be welcome.
The Scalable Matrix Extension (SME) is an extension to SVE (Scalable Vector Extension) for AArch64. It focuses on outer product instructions to accelerate matrix multiplication by utilizing a 2D tile register (ZA), which is split into multiple smaller square tiles (ZA[0-3]s, ZA[0-7]d). A 64-bit (double-word) tile can be seen as an opaque square vector<[2x2]xf64/si64/ui64>, whereas a 32-bit tile can be seen as an opaque square vector<[4x4]xf32/si32/ui32>.
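To make the tile shapes concrete, here is a small Python sketch (not part of the patch) of how the ZA array divides into square tiles for a given streaming vector length (SVL). It assumes SVL is a power of two between 128 and 2048 bits and that MLIR's vscale equals SVL / 128; the helper names are hypothetical.

```python
# Sketch of how SME's ZA array splits into square tiles (illustrative, not from the patch).
# Assumption: SVL is a power of two in [128, 2048] bits, and MLIR's vscale == SVL / 128.

VALID_SVL = (128, 256, 512, 1024, 2048)


def za_rows(svl_bits: int) -> int:
    """ZA is (SVL/8) rows, each row SVL bits wide."""
    return svl_bits // 8


def num_tiles(elem_bits: int) -> int:
    """Tiles per element size: 8-bit -> 1, 32-bit -> 4 (za0s..za3s), 64-bit -> 8 (za0d..za7d)."""
    return elem_bits // 8


def tile_shape(svl_bits: int, elem_bits: int) -> tuple:
    """Each tile is a square of (SVL / elem_bits) x (SVL / elem_bits) elements."""
    if svl_bits not in VALID_SVL:
        raise ValueError(f"unsupported SVL: {svl_bits}")
    n = svl_bits // elem_bits
    return (n, n)


for svl in (128, 512):
    vscale = svl // 128
    # vector<[2x2]xf64> and vector<[4x4]xf32> both scale with vscale.
    assert tile_shape(svl, 64) == (2 * vscale, 2 * vscale)
    assert tile_shape(svl, 32) == (4 * vscale, 4 * vscale)
    # All tiles of one element size together cover every ZA row exactly once.
    assert num_tiles(64) * tile_shape(svl, 64)[0] == za_rows(svl)
    print(svl, tile_shape(svl, 64), num_tiles(64), tile_shape(svl, 32), num_tiles(32))
# prints:
# 128 (2, 2) 8 (4, 4) 4
# 512 (8, 8) 8 (16, 16) 4
```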
This implementation enumerates the register tiles as opaque values, as opposed to a first-class type, because I think the vector dialect already provides enough of an abstraction for 2D tiles, and because this better matches the LLVM implementation of these instructions/intrinsics. I debated whether this would be better suited to the ArmSVE dialect, given that SME is an extension of SVE, but decided against it since most of the functionality we would need from it is already built into the Vector dialect.
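Because tiles are plain enumerated values rather than types, a tile name is fully described by its element size and index. The Python sketch below, with hypothetical helper names, parses names like za0d and za1s and computes which ZA rows each tile occupies. It assumes the interleaved layout from my reading of the SME architecture description: ZA row r belongs to tile (r mod N), where N is the number of tiles for that element size.

```python
import re

# Hypothetical model: a tile is an opaque (element size, index) pair, mirroring names like za0d / za1s.
# Assumption: ZA rows are interleaved, so ZA row r belongs to tile (r % N) for N tiles of that size.

SUFFIX_BITS = {"b": 8, "h": 16, "s": 32, "d": 64, "q": 128}
TILE_NAME = re.compile(r"za(\d+)([bhsdq])")


def parse_tile(name: str) -> tuple:
    """Return (element bits, tile index), rejecting out-of-range indices such as za4s."""
    match = TILE_NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"not a tile name: {name}")
    index, bits = int(match.group(1)), SUFFIX_BITS[match.group(2)]
    count = bits // 8  # 4 tiles for 32-bit elements, 8 for 64-bit
    if index >= count:
        raise ValueError(f"{name}: only {count} tiles exist for {bits}-bit elements")
    return bits, index


def tile_za_rows(name: str, svl_bits: int) -> set:
    """Set of ZA row numbers (each SVL bits wide) occupied by the named tile."""
    bits, index = parse_tile(name)
    count = bits // 8
    return set(range(index, svl_bits // 8, count))


# za2d sits inside za2s, while za1s is disjoint from both za0d and za2d.
assert tile_za_rows("za2d", 128) <= tile_za_rows("za2s", 128)
assert not tile_za_rows("za1s", 128) & (tile_za_rows("za0d", 128) | tile_za_rows("za2d", 128))
```

Under this layout, clearing za1s touches exactly the rows of za1d and za5d, which is consistent with the hardware ZERO instruction encoding its operand as a mask over the eight 64-bit tiles.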
More information on the architecture itself can be found here: Documentation – Arm Developer
Currently, this patch defines ops for most of the instructions in the extension, but lowering only supports the non-widening (i.e. fp32 and fp64) versions of the MOPA/MOPS ops, in addition to the ZERO op.
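For reference, the non-widening MOPA/MOPS semantics can be written as a plain Python loop. This is a sketch based on my reading of the Arm pseudocode, not the patch's lowering: the row predicate gates zn, the column predicate gates zm, and tile elements whose row or column is inactive are left unchanged. It ignores floating-point details such as fused rounding, and the function name is hypothetical.

```python
# Reference model of non-widening MOPA / MOPS on one tile (a sketch, not the patch's lowering).
# Assumption (my reading of the Arm pseudocode): elements with an inactive row or column are unchanged.


def outer_product_accumulate(acc, pn, pm, zn, zm, subtract=False):
    """Return acc with acc[i][j] += zn[i] * zm[j] (or -= for MOPS) where pn[i] and pm[j]."""
    n = len(acc)
    if not (len(pn) == len(zn) == n and len(pm) == len(zm) == n):
        raise ValueError("operand lengths must match the tile dimension")
    out = [list(row) for row in acc]  # do not mutate the input tile
    for i in range(n):
        for j in range(n):
            if pn[i] and pm[j]:
                product = zn[i] * zm[j]
                out[i][j] += -product if subtract else product
    return out


# vector<[2]xf64> operands with vscale = 1, i.e. a 2x2 f64 tile such as za0d.
tile = [[0.0, 0.0], [0.0, 0.0]]
tile = outer_product_accumulate(tile, [True, True], [True, False], [1.0, 2.0], [3.0, 4.0])
print(tile)  # [[3.0, 0.0], [6.0, 0.0]]
tile = outer_product_accumulate(tile, [True, True], [True, True], [1.0, 2.0], [3.0, 4.0], subtract=True)
print(tile)  # [[0.0, -4.0], [0.0, -8.0]]
```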
Here’s an example of the new ops I have defined:
func.func @arm_sme_ops(%0 : vector<[2]xf64>,
%1 : vector<[4x2]xf16>,
%2 : vector<[2x4]xsi16>) {
%c = arith.constant 128 : index
%pred.64 = vector.create_mask %c : vector<[2]xi1>
%pred.16 = vector.create_mask %c, %c : vector<[4x2]xi1>
%pred.i16 = vector.create_mask %c, %c : vector<[2x4]xi1>
// Clears the za0d and za2d 64-bit tiles and the za1s 32-bit tile
arm_sme.zero za0d, za2d, za1s
// Accumulates into an SVLd x SVLd tile register named za0d (d for double-word)
arm_sme.mopa za0d, %pred.64, %pred.64, %0, %0 : vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>
// Widening outer product, can be seen as two outer products of <[4x1]xf16> vectors accumulated together into a f32 tile.
arm_sme.mopa za1s, %pred.16, %pred.16, %1, %1 : vector<[4x2]xi1>, vector<[4x2]xi1>, vector<[4x2]xf16>, vector<[4x2]xf16>
// Integer widening outer product: i16 accumulates into i64. Element types must be signed or unsigned (not signless), and the two operands may differ in signedness
arm_sme.mopa za1d, %pred.i16, %pred.i16, %2, %2 : vector<[2x4]xi1>, vector<[2x4]xi1>, vector<[2x4]xsi16>, vector<[2x4]xsi16>
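The widening forms generalize the non-widening case: each row of a vector<[4x2]xf16> or vector<[2x4]xsi16> operand holds k elements (k = 2 or 4), and each tile element is updated with the sum of k products, which is the same as accumulating k ordinary outer products. The Python sketch below assumes per-element predicates and that a tile element is only updated when at least one of its product pairs is active, with inactive pairs contributing zero. It does not model f16 rounding or integer wrap-around, and the helper names are hypothetical.

```python
# Reference model of the widening forms (sketch; names are hypothetical).
# Each operand row holds k elements: k = 2 for vector<[4x2]xf16> -> f32, k = 4 for vector<[2x4]xsi16> -> i64.
# Assumptions: per-element predicates; a tile element changes only if at least one product pair is
# active; inactive pairs contribute zero. No f16 rounding or integer wrap-around is modeled.


def widening_outer_product_accumulate(acc, pn, pm, zn, zm):
    """acc[i][j] += sum over t of zn[i][t] * zm[j][t], for active pairs pn[i][t] and pm[j][t]."""
    out = [list(row) for row in acc]
    for i, (a_row, p_row) in enumerate(zip(zn, pn)):
        for j, (b_row, q_row) in enumerate(zip(zm, pm)):
            active = [p and q for p, q in zip(p_row, q_row)]
            if any(active):
                out[i][j] += sum(a * b for a, b, on in zip(a_row, b_row, active) if on)
    return out


def outer(col_a, col_b):
    """Plain outer product of two 1-D lists."""
    return [[a * b for b in col_b] for a in col_a]


# vscale = 1: a vector<[4x2]xf16> operand is 4 rows of 2 elements, accumulating into a 4x4 f32 tile.
zn = [[1, 2], [3, 4], [5, 6], [7, 8]]
all_true = [[True, True]] * 4
zero_tile = [[0] * 4 for _ in range(4)]
result = widening_outer_product_accumulate(zero_tile, all_true, all_true, zn, zn)

# Same as accumulating two [4x1] outer products, one per element column t.
col0, col1 = [r[0] for r in zn], [r[1] for r in zn]
two_products = [[x + y for x, y in zip(r0, r1)] for r0, r1 in zip(outer(col0, col0), outer(col1, col1))]
assert result == two_products
```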
The plan is to connect this to the Vector dialect somehow, either through the OuterProductOp or by introducing a MaskedOuterProductOp. Additionally, accessing vectors within the SME tile register should be implemented through the new load/store/move instructions.
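SME accesses tiles one horizontal or vertical slice at a time, which is the granularity the load/store/move instructions work at. As a rough illustration (hypothetical helper names, modeled on a merging move where inactive lanes keep their old values), slice reads and predicated slice writes on a square tile could look like this in Python:

```python
# Illustrative model of tile-slice access (hypothetical helpers, not the patch's ops).
# Assumption: writes merge, i.e. lanes with a false predicate keep their previous tile value.


def read_slice(tile, index, vertical=False):
    """Return row `index` (horizontal slice) or column `index` (vertical slice) of the tile."""
    if vertical:
        return [row[index] for row in tile]
    return list(tile[index])


def write_slice(tile, index, vec, pred, vertical=False):
    """Return a copy of tile with the predicated lanes of vec written into one slice."""
    out = [list(row) for row in tile]
    for lane, (value, active) in enumerate(zip(vec, pred)):
        if not active:
            continue
        if vertical:
            out[lane][index] = value
        else:
            out[index][lane] = value
    return out


tile = [[1, 2], [3, 4]]
assert read_slice(tile, 0) == [1, 2]
assert read_slice(tile, 0, vertical=True) == [1, 3]
assert write_slice(tile, 1, [9, 9], [True, False], vertical=True) == [[1, 9], [3, 4]]
```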