An update on ArmSME dialect and proposed next steps

I’d like to share an update on the ArmSME dialect and the proposed next steps.

Current state

So far we’ve introduced a pass '-enable-arm-streaming' that enables Arm Streaming SVE (SSVE) mode and access to ZA (D150934, D152695, D153050) [1], and added the first lowering from Vector → ArmSME intrinsics (D152508). This lowers a vector.transfer_write of zeroes with type vector<[16]x[16]xi8> to the SME zero {za} instruction [2], which zeroes the entire ZA array, and then writes ZA out to memory with the str instruction [2]. D154867 (in-flight) updates this lowering to use custom ops that sit on top of the intrinsics.
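
For reference, the input handled by that first lowering looks roughly like this (a sketch; %mem is a hypothetical memref value):

// A transfer_write of an all-zeroes constant covering a whole ZA byte tile.
%c0 = arith.constant 0 : index
%zero = arith.constant dense<0> : vector<[16]x[16]xi8>
vector.transfer_write %zero, %mem[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>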

Supporting load / store

I’ve been looking at supporting load/store to/from ZA for more tile sizes, which poses a few challenges. The first is that the SME intrinsics have only inputs and no outputs, so a Vector dialect op that produces a result can’t be replaced by an intrinsic while that result still has uses, unless the uses are also replaced as part of the same lowering. These kinds of “global” rewrites could quickly become quite complex.

A further issue is that the intrinsics take a 32-bit integer constant tile id for the ZA tile to use, so a mechanism is required to select this.
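
For illustration, the load intrinsic used later in this post has the following shape; note the empty result list and the i32 tile id operand:

// Fills one horizontal slice of ZA tile %tile_id from memory as a side effect;
// there is no SSA result with which to replace a value-producing Vector op.
"arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()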

To solve this, my colleagues @paulwalker-arm and @banach-space and I came up with three new (custom) ops:

  1. arm_sme.get_tile_id - returns a scalar integer representing an SME “virtual tile” that is not in use.
  2. arm_sme.cast_tile_to_vector - casts from a tile id to a 2-d scalable vector type, which represents an SME “virtual tile”.
  3. arm_sme.cast_vector_to_tile - casts from a 2-d scalable vector type, which represents an SME “virtual tile”, to a tile id.

There is also a further op, arm_sme.tile_load, to load a ZA tile from memory; it is used alongside arm_sme.tile_store in the example below.

The following example shows how these would be used when lowering from Vector dialect.

Example

1. Input (Vector dialect)

%tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

2. After converting Vector to ArmSME (-convert-vector-to-arm-sme)

%tile = arm_sme.tile_load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

3. After lowering arm_sme.tile_load to intrinsics (-convert-vector-to-llvm="enable-arm-sme")

%tile_id = arm_sme.get_tile_id : i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

The arm_sme.tile_load can’t simply be replaced with an SME intrinsic, since the intrinsic has no outputs and the load’s result is used by the arm_sme.tile_store. However, by inserting an arm_sme.cast_tile_to_vector op after the load intrinsics, the arm_sme.tile_load can be replaced: the cast provides the vector value the store consumes. This enables “local” rewrites on individual custom ArmSME ops, rather than “global” rewrites that would also have to look at the uses and lower them.

4. After lowering arm_sme.tile_store to intrinsics (-convert-vector-to-llvm="enable-arm-sme")

%tile = arm_sme.tile_load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
%tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

5. Complete lowering

%tile_id_0 = arm_sme.get_tile_id : i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id_0, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%tile_id_1 = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id_1, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

Later steps

Canonicalization could then look through arm_sme.cast_tile_to_vector and fold the cast away if it comes from an arm_sme.cast_vector_to_tile, and vice versa.
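
For example, a minimal sketch of one folding direction, using the casts from step 5 above:

%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%tile_id_1 = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
// After -canonicalize, uses of %tile_id_1 are replaced with %tile_id_0 and
// both casts become dead.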

After canonicalization (-canonicalize)

%tile_id = arm_sme.get_tile_id : i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

After tile allocation (-allocate-arm-sme-tiles)

%tile_id = arith.constant 0 : i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

Before lowering to LLVM, arm_sme.get_tile_id would be replaced with an actual tile number (a 32-bit integer constant); this is the tile allocation problem that has been discussed previously. For this we propose a pass '-allocate-arm-sme-tiles' that implements allocation of SME ZA tiles at the function level. At the beginning of a function it would assume a clean slate. This is valid since, when PSTATE.ZA is changed by any means from 0 to 1, all implemented bits of the SME ZA storage are set to zero (B1.1.1.2, RYRZRM, [1]). It’s also assumed that ZA state cannot cross the function boundary.
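
As a hypothetical sketch of what this means when two virtual tiles are live at the same time, allocation would hand out distinct tile numbers:

// Before -allocate-arm-sme-tiles: two independent virtual tiles.
%tile_id_a = arm_sme.get_tile_id : i32
%tile_id_b = arm_sme.get_tile_id : i32

// After -allocate-arm-sme-tiles: each simultaneously live virtual tile gets
// its own ZA tile number, starting from a clean slate at function entry.
%tile_id_a = arith.constant 0 : i32
%tile_id_b = arith.constant 1 : i32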

Summary

To summarise, the pipeline would be as follows (an example invocation is sketched after the list):

  1. -enable-arm-streaming="mode=locally enable-za" (turn on SSVE mode and ZA access)
  2. -convert-vector-to-arm-sme (Vector dialect → custom ArmSME ops)
  3. -convert-vector-to-llvm="enable-arm-sme" (custom ArmSME ops → SME intrinsics)
  4. -allocate-arm-sme-tiles (replace arm_sme.get_tile_id with actual tile numbers)
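
Put together, and assuming the patches above land with these option spellings, an mlir-opt invocation would look something like:

mlir-opt input.mlir \
  -enable-arm-streaming="mode=locally enable-za" \
  -convert-vector-to-arm-sme \
  -convert-vector-to-llvm="enable-arm-sme" \
  -allocate-arm-sme-tiles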

I shared two patches earlier in the week that implement part of this proposal:

  • D154941 - Add custom get_tile_id and cast ops
  • D154955 - Implement tile allocation

and I’ll share a further patch that implements load / store shortly.

We believe the interface provided by these custom ops will enable more complex examples.

The patches listed in this update have been under review for a few days now and so far the feedback has been positive. I wanted to share the update on Discourse as well because these are uncharted waters for us and we want to make sure that we are not missing anything obvious.

We’d love to hear thoughts from the community on this, let us know if you have any questions or comments.

Thanks for reading!

[1] Documentation – Arm Developer
[2] Documentation – Arm Developer

Load / store patch: D155306 [mlir][ArmSME] Add tile load op and extend tile store tile size support
