I recently shared an update on the ArmSME dialect and the proposed next steps, along with several patches that have now landed:
- D154941 - Add custom get_tile_id and cast ops
- D154867 - Introduce custom ops for SME
- D154955 - Implement tile allocation
- D155306 - Add tile load op and extend tile store tile size support
On D154867, Diego pointed out that we're materializing a loop when lowering to intrinsics:
> Something important here: we introduce the SME lowering layer to explicitly model what is needed for SME and make the conversion to LLVM easier. However, here we are materializing a loop. I'm wondering why that loop is not generated when we move from Vector to the SME dialect and then the conversion to LLVM is mostly a 1:1 translation to the intrinsics.
>
> It sounds good to me to do this separately but this is a big abstraction change so hopefully we can do it sooner than later. If you think the non-loop abstraction is also useful, we could also have two levels of abstraction within the same dialect, where we go first to the non-loop one and then materialize the loop at some point within the SME dialect. The Vector dialect is a good example of this.
The reason for this is that the SME load and store instructions [1] load or store tile slices: 1D vectors of horizontally or vertically contiguous elements within a ZA tile. To load or store an entire tile we therefore emit a loop when lowering to intrinsics (`-convert-vector-to-llvm="enable-arm-sme"`), for example:
```mlir
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```
To resolve this, we (@c-rhodes, @paulwalker-arm, @banach-space) have come up with two new ops that map 1-1 with the intrinsics:
- `arm_sme.tile_slice_load_and_update` - loads a 1D tile slice from memory into a 2D SME “virtual tile”.
- `arm_sme.tile_slice_store` - stores a 1D tile slice from a 2D SME “virtual tile” into memory.

We also propose a new conversion pass, `-convert-arm-sme-to-scf`, to materialize loops with these ops.
The following example (based on the example from the previous update) shows how these would be used.
Example
1. Input (Vector dialect) [UNCHANGED]
```mlir
%tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
```
2. After converting Vector to ArmSME (`-convert-vector-to-arm-sme`) [UNCHANGED]
```mlir
%tile = arm_sme.tile_load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
```
3. After lowering `arm_sme.tile_load` to SCF + ArmSME (`-convert-arm-sme-to-scf`) [NEW]
```mlir
%tile_id = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  %tile_update = arm_sme.tile_slice_load_and_update %mem1[%tile_slice_index], %tile_init, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
}
%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
```
The `arm_sme.cast_tile_to_vector` cast is emitted after the loop so that canonicalizations don't have to look through the loop.
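For illustration, this is a sketch of the kind of fold that placement enables (the exact canonicalization pattern may differ from the patches):
```mlir
// A cast round-trip at the same nesting level:
%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
%tile_id_again = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
// ...folds away, replacing uses of %tile_id_again with %tile_id directly.
```
Because the cast back to a vector sits at the same level as `arm_sme.get_tile_id`, the fold never has to trace a value through an `scf.for` region boundary.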
4. After lowering `arm_sme.tile_store` to SCF + ArmSME (`-convert-arm-sme-to-scf`) [NEW]
```mlir
%tile = arm_sme.tile_load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  arm_sme.tile_slice_store %tile, %tile_slice_index, %mem2[%tile_slice_index] : memref<?x?xi32>, vector<[4]x[4]xi32>
}
```
5. After lowering `arm_sme.tile_slice_load_and_update` to intrinsics (`-convert-vector-to-llvm="enable-arm-sme"`) [NEW]
```mlir
%tile_id_0 = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_1 = arm_sme.cast_vector_to_tile %tile_init : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id_1, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
  %tile_update = arm_sme.cast_tile_to_vector %tile_id_1 : i32 to vector<[4]x[4]xi32>
}
%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  arm_sme.tile_slice_store %tile, %tile_slice_index, %mem2[%tile_slice_index] : memref<?x?xi32>, vector<[4]x[4]xi32>
}
```
The final cast uses the original `%tile_id_0`. The dataflow isn't perfectly modelled here, but as long as the ops aren't reordered the semantics are sound, and marking the ops with side effects can prevent such reordering.
6. After lowering `arm_sme.tile_slice_store` to intrinsics (`-convert-vector-to-llvm="enable-arm-sme"`) [NEW]
```mlir
%tile_id_0 = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  %tile_update = arm_sme.tile_slice_load_and_update %mem1[%tile_slice_index], %tile_init, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
}
%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id_1, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```
7. Complete lowering
```mlir
%tile_id_0 = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_1 = arm_sme.cast_vector_to_tile %tile_init : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id_1, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
  %tile_update = arm_sme.cast_tile_to_vector %tile_id_1 : i32 to vector<[4]x[4]xi32>
}
%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_2 = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id_2, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```
8. After canonicalization (`-canonicalize`) [UNCHANGED]
```mlir
%tile_id = arm_sme.get_tile_id : i32
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```
9. After tile allocation (`-allocate-arm-sme-tiles`) [UNCHANGED]
```mlir
%tile_id = arith.constant 0 : i32
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```
Summary
To summarise, the pipeline would be as follows:
- `-enable-arm-streaming="mode=locally enable-za"` (turn on streaming SVE (SSVE) and ZA access)
- `-convert-vector-to-arm-sme` (Vector dialect → custom ArmSME ops)
- `-convert-arm-sme-to-scf` (custom ArmSME ops → SCF + custom ArmSME ops) (NEW)
- `-convert-vector-to-llvm="enable-arm-sme"` (custom ArmSME ops → SME intrinsics) (UPDATED)
- `-allocate-arm-sme-tiles` (replace `arm_sme.get_tile_id` with actual tiles)
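Putting the above together, a sketch of what invoking the full pipeline with `mlir-opt` could look like (the input file name is illustrative; the flags are the ones listed above):
```sh
mlir-opt example.mlir \
  -enable-arm-streaming="mode=locally enable-za" \
  -convert-vector-to-arm-sme \
  -convert-arm-sme-to-scf \
  -convert-vector-to-llvm="enable-arm-sme" \
  -allocate-arm-sme-tiles
```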
This pipeline will enable lowering ArmSME ops to ops that match 1-1 with the intrinsics where necessary. The MOVA tile-to-vector and vector-to-tile instructions also operate on tile slices and could use the same abstraction; the latter will be useful for lowering `vector.broadcast`.
I've shared a patch on Phabricator implementing this (D156467 [mlir][ArmSME] Add conversion from ArmSME to SCF to materialize loops), but thought I would share a more in-depth discussion of the approach here. As always, we'd appreciate any thoughts from the community; let us know if you have any questions or comments.
Cheers!
Cullen