[RFC] Creating an ArmSME Dialect

Here I have the initial implementation of an ArmSME dialect: ⚙ D139875 [RFC][MLIR][ArmSME] Initial implementation of ArmSME Dialect, any feedback would be welcome.

The Scalable Matrix Extension (SME) is an extension to SVE (the Scalable Vector Extension) for AArch64, and focuses on outer product instructions to accelerate matrix multiplies by utilizing a 2D tile register (ZA), which is split into multiple smaller square tiles (ZA[0-3]s, ZA[0-7]d). A 64-bit (double-word) tile can be seen as an opaque square vector<[2x2]xf64/si64/ui64>, whereas a 32-bit tile can be seen as an opaque square vector<[4x4]xf32/si32/ui32>.
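
For illustration (this mirrors the notation used in the example below; it is not new syntax defined by the patch), the two tile shapes would be:

// A 64-bit-element tile: a scalable [2x2] square, i.e. (SVL/64) x (SVL/64) elements at runtime.
vector<[2x2]xf64>
// A 32-bit-element tile: a scalable [4x4] square, i.e. (SVL/32) x (SVL/32) elements at runtime.
vector<[4x4]xf32>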

This implementation enumerates the register tiles as opaque values, as opposed to a first-class type, because I think the vector dialect already provides enough of an abstraction for 2D tiles, and so that this can better match the LLVM implementation of these instructions/intrinsics. I have debated whether or not this would be better suited to go into the ArmSVE dialect, seeing how it is an extension of it, but decided against it since most of the functionality we need from it is built into the Vector dialect.

More information on the architecture itself can be found here: Documentation – Arm Developer

Currently this patch defines most of the instructions introduced by the extension, but lowering only supports the non-widening (i.e. fp32 and fp64) versions of the MOPA/MOPS ops, in addition to the ZERO op.

Here’s an example of the new ops I have defined:

func.func @arm_sme_ops(%0 : vector<[2]xf64>,
                       %1 : vector<[4x2]xf16>,
                       %2 : vector<[2x4]xsi16>) {
  %c = arith.constant 128 : index
  %pred.64 = vector.create_mask %c : vector<[2]xi1>
  %pred.16 = vector.create_mask %c, %c : vector<[4x2]xi1>
  %pred.i16 = vector.create_mask %c, %c : vector<[2x4]xi1>

  // Clears the za0d and za2d 64-bit tiles and the za1s 32-bit tile
  arm_sme.zero za0d, za2d, za1s

  // Accumulates into a SVLd x SVLd tile register named za0d (d for double-word)
  arm_sme.mopa za0d, %pred.64, %pred.64, %0, %0 : vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>

  // Widening outer product: can be seen as two outer products of <[4x1]xf16> vectors accumulated together into an f32 tile.
  arm_sme.mopa za1s, %pred.16, %pred.16, %1, %1 : vector<[4x2]xi1>, vector<[4x2]xi1>, vector<[4x2]xf16>, vector<[4x2]xf16>

  // Integer widening outer product: i16 accumulates into i64; element types have to be signed or unsigned (not signless), and the two operands can be of different signedness
  arm_sme.mopa za1d, %pred.i16, %pred.i16, %2, %2 : vector<[2x4]xi1>, vector<[2x4]xi1>, vector<[2x4]xsi16>, vector<[2x4]xsi16>

  return
}

The plan is to somehow connect this to the vector dialect, either through the OuterProductOp or by introducing a MaskedOuterProductOp. Additionally, accessing vectors from within the SME tile register should be implemented through the new load/store/move instructions.
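
To sketch what that connection could look like (purely illustrative; no such lowering exists in this patch, and the snippet reuses the values from the example above):

// A plain outer product expressed with the existing vector dialect op...
%op = vector.outerproduct %0, %0 : vector<[2]xf64>, vector<[2]xf64>
// ...which a lowering pass, given the predicates, could rewrite into:
//   arm_sme.mopa za0d, %pred.64, %pred.64, %0, %0
//       : vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>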

1 Like

Seems odd that you had to use register names at this high level. If those are just interchangeable registers, then register allocation or a library implementation should be the one choosing them. If they’re ABI registers, then code generation should take care of them before any of that runs.

I imagine today there’s only one reasonable implementation, but as time passes, new products could have different implementations (multiple pipelines, longer vectors) and the trade-offs won’t be the same, so the IR may not even lower (non-divisible ranges, for example) or may just generate really poor code.

To me, having predicate types, data types, etc. makes sense. Directly referencing registers doesn’t. Unless I’m missing something here…

2 Likes

Thanks for this proposal!

I am not aware of any existing implementations of SME. IIUC, we won’t be able to do any end-to-end testing in the near future. Unless I missed some announcement :slight_smile:

Well, what I mean is: the extension is designed in a way that matmul is implemented as a known sequence of ops on a known tile size, so as long as you tile correctly the core code always uses the same registers.

The TPP paper proposes one such tiling algorithm for multiple extensions (p9). We could easily add SME to that list, as I imagine it matches existing tile registers in other extensions + outer product etc.
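
To make the “known sequence” concrete, here is a rough sketch with fixed sizes and made-up names (purely illustrative; %At is assumed to hold A transposed so that a column of A is a contiguous row):

// Accumulate one 4x4 tile of C as a sum of outer products over the reduction dimension.
func.func @tile_via_outerproduct(%At : memref<128x128xf32>, %B : memref<128x128xf32>,
                                 %i : index, %j : index) -> vector<4x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c128 = arith.constant 128 : index
  %zero = arith.constant dense<0.0> : vector<4x4xf32>
  %tile = scf.for %k = %c0 to %c128 step %c1 iter_args(%acc = %zero) -> (vector<4x4xf32>) {
    %a = vector.load %At[%k, %i] : memref<128x128xf32>, vector<4xf32>  // column of A
    %b = vector.load %B[%k, %j] : memref<128x128xf32>, vector<4xf32>   // row of B
    %next = vector.outerproduct %a, %b, %acc : vector<4xf32>, vector<4xf32>
    scf.yield %next : vector<4x4xf32>
  }
  return %tile : vector<4x4xf32>
}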

But in the future, along comes another CPU, say ARMv11, with SME2 and a slightly different design. Or even before that, some ARMv10 with SME but with two units instead of one, or with different register usage; then the register you baked into the IR isn’t a good match.

That’s why we leave registers for MIR in LLVM, not LLVM IR.

So, I may be missing something, but I can’t understand the reason to have register names in an IR that is even higher level than LLVM IR.

I’d also like to pause and reflect on the proliferation of dialects for hardware extensions… What is the actual value that this is bringing? How are we supposed to make general statements about IR passes if all those dialects are so low-level?

Can we have dialects that describe the general ideas, allowing for special cases, and then match those ideas to the concepts in each extension as a lowering phase? I mean, most vector/matrix extensions have basically the same register types and very similar ops, why duplicate them in this manner?

Back in the vector days in LLVM, there was a big discussion about the tradeoffs between intrinsics and generic IR patterns. I think the conclusions back then are still valid today: generic patterns are more expressive, and we only have intrinsics for the things that are impossible to represent in generic patterns.

So, perhaps (thinking out loud) we should think about moving most of the existing ops in AVX/SVE/NEON to a vector/matrix dialect and avoid creating special types, or individual registers, in the process?

At least Intel is shipping something similar: AMX.

Do you really need to hardcode the register names or is it an inherent feature of ZA?

There is SME2.1.

1 Like

I agree that directly listing register names doesn’t make much sense at this level, and I would like to implement this in a better way if possible. However, I would like to give some context and my thoughts from implementing this…

In the current LLVM implementation of SME (e.g. ⚙ D127843 [AArch64][SME] Add the zero intrinsic), the matrix tiles are not implemented as a register but as an immediate argument to the SME intrinsics. (IIRC there did exist another implementation of SME with <mscale x 4 x f64> for tile registers which handled register allocation, spilling, and such, but that did not get accepted?)

We could implement these tiles as a 2D vector type in this dialect; however, that would also mean implementing functionality similar to register allocation in MLIR as well, which I don’t think is a great idea either. I admit that there is probably a better way of implementing this SME tile type, and I would be open to adopting any ideas that make sense.

As to the purpose of this dialect altogether (as it stands currently), there are certain higher-level operations (not just matmul) that could be accelerated by utilizing the outer product instructions in SME, and there is no clear way to lower these to SME/LLVM with the existing functionality. Also, if we were to extend the vector dialect to support these operations, there would still need to be an interface to lower to LLVM IR intrinsics etc. directly, and I would say these probably make more sense in their own dialect as opposed to bloating up the LLVMIR dialect…

This has to be a sequence of “known patterns” and not hard-coded up here.

So, you lower that as something like a 2D tensor/memref; it then lowers to vectors in LLVM, which will be recognised as the registers you want at lowering.

If LLVM IR already has constructs for this, then you just need to make sure those memrefs lower to the right LLVM Dialect types when lowering the SME ops operands and result types.

So we can potentially do something like this:

// Get an initialized tile, but would require a redundant "zero" 
// instruction if the tile is loaded from memory
%tile = arm_sme.zero : <[4x4]xf32>

// This would technically not be SSA
arm_sme.fmopa %tile, ... : <[4x4]xf32>, ...

// OR ===================================================

// If some other operation were to use the original %tile
// then we will have to copy the %tile vector, and if we run
// out of tiles then we will have to spill to memory...
%newTile = arm_sme.fmopa %tile, ...

I believe there are two challenges to solve.

  • For the ArmSME dialect you have to find a representation of ZA that is close to the real hardware and can be lowered to LLVM with reasonable effort.

  • The vector dialect now supports fixed length and scalable vectors without tying itself to hardware. It should also learn some high-level tiled-matrix concept.

I added the AMX dialect a while back, and tried to find common ground between the architecturally neutral aspects of the vector dialect and the architectural specifics of the AMX extensions (see this 2021 ODM presentation).

In particular, I tried to avoid going all the way down to register level, staying within the vector dialect as follows:

%4 = amx.tile_zero : vector<16x16xf32>
amx.tile_store %arg2[%i, %j], %4
    : memref<128x128xf32>, vector<16x16xf32>

lowers ultimately into

tilezero %tmm0
leaq (%r10,%rcx), %rdi
tilestored %tmm0, (%rdi,%r11)

1 Like

I don’t see the zero instruction as redundant; I see it as semantics. What happens in IR doesn’t necessarily have to be what goes on in hardware.

It would be great to see a study between AMX and SME and fuse the implementation of both into the same matrix dialect (to compose with the vector dialect?).

It would also be good to merge SVE/AVX into one, using vector types, but that may be a larger fish to fry, and is not relevant to this thread.

Just as another factoid for background, the vector dialect is a bit of a misnomer, since it really deals with n-dimensional SSA values (as Nicolas stated it, higher-order vectors). This is mainly the result of the fact that the more appropriate name “tensor” was already claimed by the tensor dialect. In retrospect, perhaps other names would have been better (SIMD dialect?), but the generalization to at least 2-dim matrices was already there from the start. So hopefully both AMX and SME will compose nicely with the 2-D “vector” dialect.

The issue is that SME is the Scalable Matrix Extension, while AMX tiles are fixed-size matrices.

Right, so this would be a mix of arm_sve and amx extensions, which isn’t a trivial exercise.

I don’t want to derail this RFC, I just really don’t get what’s the point of having such low-level dialects in MLIR (and that includes the existing AVX/NEON/SVE ones). LLVM already has a good enough abstraction for those, so lowering to LLVM dialect would make sense here.

The only reason why we would need those dialects is to do further transformations in them, but if now we create a new dialect to every extension, really, how much can we do at high level that isn’t (or shouldn’t be) replicated in LLVM already?

Yeah, I was trying to name the ops dialect matrix, using the vector dialect types, but because those names are overloaded, it’s hard to know what is what. But SIMD or similar may be a better name for an ops dialect.

And adding scalability to vector could be a way to join them all in one thing?

That is, if it does make sense to have such dialects in MLIR in the first place, which is what I’m questioning. How much vectorization do we really want to duplicate here, from LLVM? Are we trying to create another compiler entirely, and if so, does it still make sense for MLIR to be in LLVM?

(that escalated quickly, but you get the idea, our decisions need to take in the long term cost, and I’m not seeing long term benefits for any of these very low level “dialects” in MLIR).

Random thought: how does this all compose with the LLVM matrix intrinsics? LLVM Language Reference Manual — LLVM 16.0.0git documentation
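
(For reference, those intrinsics surface in the MLIR LLVM dialect roughly like this; the sizes are arbitrary and only meant to show the flattened 1-D vector form they operate on:)

// 4x16 times 16x3 multiply on flattened vectors.
%C = llvm.intr.matrix.multiply %A, %B
    { lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32 }
    : (vector<64xf32>, vector<48xf32>) -> vector<12xf32>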

Do those get good performance in practice or is there a reason we are avoiding lowering through them and going lower level?

Personally I don’t mind having these “lower level” hardware-specific dialects. Just we need to be careful about responsibilities like register allocation / scheduling and how all this composes with lowering through LLVM (or not?).

2 Likes

First, I wanted to say that I am very excited to see this being proposed :slight_smile: I should also mention that I am a bit biased - I work at Arm (though I’ve not really worked on SME).

That’s already available, right? (see vector.vscale).

IIUC, the rationale was to leverage the knowledge of the underlying ISA to generate better code while still having access to the context only available in the frontend. Is this how the AVX/NEON/SVE dialects are used ATM? I am not sure; I haven’t worked much in this area.

However, I think that SME is “different” enough that having a dedicated MLIR dialect is going to be crucial to generate good code. Put differently, I expect there to be more low-hanging fruit with SME than with e.g. NEON/SVE.

This is a good question. I would also like to better understand how other, similar dialects in MLIR are used, to inform the design of ArmSME. As pointed out by others:

  • the AMX dialect - another matrix extension with a dedicated dialect,
  • the SVE dialect - another extension for “scalable” vectors with a dedicated dialect.

Given that this is a “scalable” extension, is scalable vectorization in MLIR ready for this? AFAIK, the sparse compiler that @aartbik has been developing is the only framework that has a vectorizer capable of generating scalable vectors ATM.

Perhaps I am getting ahead of myself, but it would be great to understand how to get there and how the ArmSME dialect fits in. I’m confident it does, it’s just not yet clear “how” :slight_smile:

-Andrzej

I could find: AMX, ArmNeon, ArmSVE, and X86Vector dialects.

Thanks for the proposal @frank_gao, and all the great work; it looks very interesting. (Note: I work for Arm.)

In general I feel this is timely, as more and more HW extensions (and not only) operate on 2D tiles/matrices, and it is worth at least discussing possible abstractions and how to proceed in such scenarios.

I can understand the usage of the ZA semantics in your modelling, but indeed generalising and being able to more robustly specify semantics as the architecture gets extended is important. Things like ZA (tile) slicing, operations with vectors of vectors on top of the ZA array, etc. There might also be some layering issues that we need to resolve around SVE and SME. SME does have an SVE state attached to it, but it operates in a different mode called “Streaming Mode”, where the vector length could be different compared to the default one. There are start/end semantics to indicate when we operate in streaming mode and when we don’t; do we need/have to model this?

Yes, scalability is there in vector and SVE is composed on top of it. My assumption would be to follow something along these lines for SME as well? Although SME introduces agnosticism on two dimensions, as the ZA tile size is [SVL x SVL], where SVL is the Streaming Vector Length.

Overall, I do personally think that these lower level dialects do serve their purpose by enabling transformations leveraging more contextual information but @_sean_silva is right that we need to be careful and draw a line on what their responsibilities are.

Great proposal, Frank :slight_smile:

I do wonder, though, if we really need to split this into its own dialect. When I wrote the SVE dialect, the original purpose was for it to hold all these arch-specific instructions that don’t have an architecture-independent mapping in LLVM IR (and to include a scalable vector type, but that’s not an issue here any more). I understand that, within Arm, SVE, SVE2, and SME are different extensions, but I don’t see any real reason to separate them into different dialects in MLIR. These dialects are not supposed to work as backends, they’re not meant to include the whole ISA; they’re just an exit for specialized vector instructions. I believe the original dialect can serve this “exit” purpose for all scalable extensions, and it will avoid having one dialect per extension, which can grow fast. Individually, these dialects would only have a handful of ops anyway; even if you put them all together, you’d end up with a relatively small dialect.

If you want to avoid the confusion with internal naming conventions, it may be worth considering renaming the dialect to “ArmVector”, not unlike what we have for x86 with “X86Vector”. For the same reasons, merging this dialect with “ArmNeon” might be something worth considering, but I’ll leave that discussion for the future.

For those who may find the purpose of these dialects confusing, the way I’ve been using ArmSVE downstream is by adding passes like:

-arm-sve-vls-vector-contract-to-mmla

It goes over all vector.contract ops with the right semantics and replaces them with arm_sve.smmla (plus glue ops). I got from things like linalg.matmul down to these vector.contract ops using passes that already existed in MLIR. Conveniently, there was also one that lowered linalg.matmul to vector.outerproduct, which maps very nicely to SME’s mopa.
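
For the curious, the arch-specific op that rewrite targets looks roughly like this (operand names are made up and the surrounding glue ops are omitted):

// Signed 8-bit integer matrix multiply-accumulate on scalable vectors.
%res = arm_sve.smmla %acc, %lhs, %rhs : vector<[16]xi8> to vector<[4]xi32>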

In this scenario, the way you control which extension to target (SVE/SVE2/SME) is at pass level.

I hope this helps! :smiley:

2 Likes