Loop materialization in ArmSME

I recently shared an update on the ArmSME dialect and the proposed next steps, along with several patches that have now landed:

  • D154941 - Add custom get_tile_id and cast ops
  • D154867 - Introduce custom ops for SME
  • D154955 - Implement tile allocation
  • D155306 - Add tile load op and extend tile store tile size support

On D154867, Diego pointed out that we’re materializing a loop when lowering to intrinsics:

Something important here: we introduce the SME lowering layer to explicitly model what is needed for SME and make the conversion to LLVM easier. However, here we are materializing a loop. I’m wondering why that loop is not generated when we move from Vector to the SME dialect and then the conversion to LLVM is mostly a 1:1 translation to the intrinsics.

It sounds good to me to do this separately but this is a big abstraction change so hopefully we can do it sooner rather than later. If you think the non-loop abstraction is also useful, we could also have two levels of abstraction within the same dialect, where we go first to the non-loop one and then materialize the loop at some point within the SME dialect. The Vector dialect is a good example of this.

The reason for this is that the SME load and store instructions [1] operate on tile slices, which are 1D vectors of horizontally or vertically contiguous elements within a ZA tile. To load or store an entire tile we therefore emit a loop when lowering to intrinsics (-convert-vector-to-llvm="enable-arm-sme"), for example:

scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

To resolve this, we (@c-rhodes, @paulwalker-arm, @banach-space) have come up with two new ops that map 1-1 to the intrinsics:

  1. arm_sme.tile_slice_load_and_update - Loads a 1D tile slice from memory into a 2D SME “virtual tile”.
  2. arm_sme.tile_slice_store - Stores a 1D tile slice from a 2D SME “virtual tile” to memory.

We also propose a new conversion pass, -convert-arm-sme-to-scf, to materialize loops around these ops.

The following example (based on the one from the previous update) shows how these would be used.

Example

1. Input (Vector dialect) [UNCHANGED]

%tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

2. After convert Vector to ArmSME (-convert-vector-to-arm-sme) [UNCHANGED]

%tile = arm_sme.tile_load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

3. After lowering arm_sme.tile_load to SCF + ArmSME (-convert-arm-sme-to-scf) [NEW]

%tile_id = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  %tile_update = arm_sme.tile_slice_load_and_update %mem1[%tile_slice_index], %tile_init, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
}
%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %mem2[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

The arm_sme.cast_tile_to_vector cast is emitted after the loop so that canonicalizations (such as folding a cast_tile_to_vector / cast_vector_to_tile round-trip back to the original tile ID) don’t have to look through the loop.

4. After lowering arm_sme.tile_store to SCF + ArmSME (-convert-arm-sme-to-scf) [NEW]

%tile = arm_sme.tile_load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  arm_sme.tile_slice_store %tile, %tile_slice_index, %mem2[%tile_slice_index] : memref<?x?xi32>, vector<[4]x[4]xi32>
}

5. After lowering arm_sme.tile_slice_load_and_update to intrinsics (-convert-vector-to-llvm="enable-arm-sme") [NEW]

%tile_id_0 = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%num_tile_slices  = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_1 = arm_sme.cast_vector_to_tile %tile_init : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id_1, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
  %tile_update = arm_sme.cast_tile_to_vector %tile_id_1 : i32 to vector<[4]x[4]xi32>
}

%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>

scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  arm_sme.tile_slice_store %tile, %tile_slice_index, %mem2[%tile_slice_index] : memref<?x?xi32>, vector<[4]x[4]xi32>
}

The final cast uses the original %tile_id_0. The dataflow isn’t perfectly modelled, but as long as the ops are not reordered the semantics are sound, and the ops’ side effects prevent such reordering.

6. After lowering arm_sme.tile_slice_store to intrinsics (-convert-vector-to-llvm="enable-arm-sme") [NEW]

%tile_id_0 = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  %tile_update = arm_sme.tile_slice_load_and_update %mem1[%tile_slice_index], %tile_init, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
}

%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>

scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id_1, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

7. Complete lowering

%tile_id_0 = arm_sme.get_tile_id : i32
%tile_init = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>
%num_tile_slices = arith.muli %c4, %vscale : index
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_1 = arm_sme.cast_vector_to_tile %tile_init : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id_1, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
  %tile_update = arm_sme.cast_tile_to_vector %tile_id_1 : i32 to vector<[4]x[4]xi32>
}
%tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32>

scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  %tile_id_2 = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id_2, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

8. After canonicalization (-canonicalize) [UNCHANGED]

%tile_id = arm_sme.get_tile_id : i32
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

9. After tile allocation (-allocate-arm-sme-tiles) [UNCHANGED]

%tile_id = arith.constant 0 : i32
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %tile_slice_index) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

Summary

To summarise, the pipeline would be as follows:

  1. -enable-arm-streaming="mode=locally enable-za" (enable streaming SVE (SSVE) and ZA access)
  2. -convert-vector-to-arm-sme (Vector dialect → custom ArmSME ops)
  3. -convert-arm-sme-to-scf (custom ArmSME ops → SCF + custom ArmSME ops) (NEW)
  4. -convert-vector-to-llvm="enable-arm-sme" (custom ArmSME ops → SME intrinsics) (UPDATED)
  5. -allocate-arm-sme-tiles (Replace arm_sme.get_tile_id with actual tiles)

This will enable lowering ArmSME ops to ops that map 1-1 with the intrinsics where necessary. The MOVA tile-to-vector and vector-to-tile instructions also operate on tile slices and could use the same abstraction; the vector-to-tile case will be useful for lowering vector.broadcast.
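For reference, the steps above could be chained in a single mlir-opt invocation, roughly as follows. This is only a sketch: the pass names and options are copied from the list above, -canonicalize is added to match step 8 of the walkthrough, and input.mlir is just an illustrative file name.

mlir-opt input.mlir \
  -enable-arm-streaming="mode=locally enable-za" \
  -convert-vector-to-arm-sme \
  -convert-arm-sme-to-scf \
  -convert-vector-to-llvm="enable-arm-sme" \
  -canonicalize \
  -allocate-arm-sme-tiles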

I’ve shared a patch on Phabricator implementing this, D156467 [mlir][ArmSME] Add conversion from ArmSME to SCF to materialize loops, but thought I would share a more in-depth discussion of the approach here. As always, we’d appreciate any thoughts from the community on this; let us know if you have any questions or comments.

Cheers!

Cullen

[1] Documentation – Arm Developer


This looks fantastic to me, thanks! The lowering to LLVM IR becomes really simple and we now have two layers within SME that will allow us to apply transformations at the most appropriate level of abstraction. Do you see value in preserving arm_sme.tile_load/store instead of going directly to the scf.for representation from Vector?

This looks fantastic to me, thanks! The lowering to LLVM IR becomes really simple and we now have two layers within SME that will allow us to apply transformations at the most appropriate level of abstraction.

Thanks Diego! Appreciate you pushing for this.

Do you see value in preserving arm_sme.tile_load/store instead of going directly to the scf.for representation from Vector?

That’s a good point; I can’t see any reason why we couldn’t do that. The only question that comes to mind is what would happen to -convert-arm-sme-to-scf. I suppose that would move to -convert-vector-to-arm-sme?


Yep, that would make sense to me. Don’t get me wrong, I like proper layering, but only when there is a purpose behind each specific layer. If that is not the case, it’s better to drop the layer: that removes one transformation step, improves compile time, and simplifies the abstraction in general. Target-specific transformations should help with this.

Thanks for being so open to considering feedback! I’m really excited about how this is evolving!