[PSA] ArmSME lowering pipeline and tile allocation changes

TLDR

We are changing the ArmSME lowering pipeline. The TLDR is:

  • -allocate-arm-sme-tiles will no longer be a standalone pass
    • (though for testing purposes we plan to have a -test-arm-sme-tile-allocation pass)
  • -convert-arm-sme-to-llvm must happen after -convert-scf-to-cf
    • SME tile allocation will be done automatically as part of the conversion

The only changes needed in existing uses are to remove the tile allocation pass and to ensure that -convert-arm-sme-to-llvm runs after SCF is lowered to control flow.
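To illustrate, a pass pipeline would change roughly along these lines (the surrounding passes are elided and illustrative; the exact pipeline depends on your use case):

```
# Before: standalone tile allocation pass
mlir-opt ... -allocate-arm-sme-tiles ... -convert-arm-sme-to-llvm ...

# After: no standalone allocation (it happens inside the conversion),
# and SCF is lowered to control flow first
mlir-opt ... -convert-scf-to-cf -convert-arm-sme-to-llvm ...
```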

End of TLDR.

Context

This is part of an effort to improve the correctness/robustness of the ArmSME dialect, specifically how tile values (e.g. values of type vector<[4]x[4]xf32>) are mapped to tile IDs (i.e. SME tile registers).

History

First, we modeled tile IDs as MLIR values, via arm_sme.get_tile_id ops that were simply replaced with constants:

%tile_id = arm_sme.get_tile_id : i32
%tile = arm_sme.cast_tile_id_to_vector %tile_id : vector<[4]x[4]xf32>
scf.for %i = %c0 to %c10 step %c1 {
  // Not SSA! Mutates %tile in place!
  arm_sme.move_vector_to_tile_slice %some_vector, %tile[%i] : vector<[4]xf32> into vector<[4]x[4]xf32>
}

However, this approach had quite a few issues. It only worked for programs not in SSA form, so it relied on all ArmSME operations having memory side effects. IR in SSA form (e.g. loops with iter_args) could not be lowered, as tile IDs appeared non-constant.

There were also other standard allocation issues: for example, tiles were never released, as the allocator had no understanding of the program (or its control flow); it simply replaced get_tile_id ops with arith.constant ops.

This made lowering to the ArmSME dialect tricky: you would go from dialects like arith and vector, where operations are pure and side-effect free, to a dialect where every operation has side effects, which can lead to correctness issues.


The first step in improving this was switching to attributes for tile IDs. With this change, get_tile_id was no more, and ArmSME operations could be (mostly) correctly used in SSA IR:

%initTile = arm_sme.get_tile : vector<[4]x[4]xf32>
%newTile = scf.for %i = %c0 to %c10 step %c1 iter_args(%tile = %initTile) -> vector<[4]x[4]xf32> {
  %newTile = arm_sme.move_vector_to_tile_slice %some_vector, %tile[%i] : vector<[4]xf32> into vector<[4]x[4]xf32>
  scf.yield %newTile : vector<[4]x[4]xf32>
}

However, the SSA form was merely surface level: the allocator did not use it to reason about values; it simply walked the IR and assigned tile IDs. This had the same underlying issues as before (implicit side effects, tiles never being released, etc.), along with basic correctness issues in how tile IDs were assigned. So lowering to ArmSME was a little nicer, as programs could be in a standard form, but there were still known issues.

Next Improvement

The next step in improving this is to analyze the program in order to allocate the tiles. The plan is to run MLIR's liveness analysis on ArmSME operations after lowering to control flow (as control flow has much simpler semantics than dialects like SCF). Using the liveness information, we can construct live ranges and implement a simple linear-scan register allocator. Note: this is not too far off from how PDL allocates bytecode memory indices (which is one of the closest things to a register allocator in MLIR).
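To make the idea concrete, here is a minimal Python sketch of linear-scan allocation over live ranges. It assumes four available tiles (as for 32-bit element types, ZA0.S–ZA3.S); the function names and interval representation are hypothetical and purely illustrative, not the actual MLIR implementation:

```python
def allocate_tiles(live_ranges, num_tiles=4):
    """Assign a tile ID to each live range using linear scan.

    live_ranges: list of (start, end) half-open intervals, one per
    tile value. Raises if more than num_tiles ranges overlap (a real
    allocator would have to spill, e.g. to memory).
    """
    free = list(range(num_tiles))  # tile IDs not currently in use
    active = []                    # (end, tile_id) for live ranges
    allocation = []
    for start, end in sorted(live_ranges):
        # Expire ranges that ended before this one starts, returning
        # their tiles to the free list (tiles are released
        # automatically, driven by liveness).
        for e, tid in list(active):
            if e <= start:
                active.remove((e, tid))
                free.append(tid)
        if not free:
            raise RuntimeError("out of tiles: would need to spill")
        tid = min(free)
        free.remove(tid)
        active.append((end, tid))
        allocation.append(tid)
    return allocation

# Three overlapping ranges need distinct tiles; the fourth reuses the
# tile freed when the short-lived third range expires.
print(allocate_tiles([(0, 4), (1, 5), (2, 3), (3, 6)]))  # [0, 1, 2, 2]
```

The key property this buys us is exactly the one mentioned above: releasing tiles falls out of the liveness information for free, rather than requiring explicit bookkeeping in the IR.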

By integrating this into the ArmSME -> LLVM conversion we can allow high-level (value-based) ArmSME operations to be side-effect free, as we can guarantee nothing will rearrange ArmSME operations before we emit intrinsics (which could invalidate the tile allocation).

This also comes with other general improvements. For example, by using liveness information, releasing tiles once they are no longer in use happens automatically (as you would expect from a compiler).

The hope is for ArmSME operations to have no hidden state or side effects, making it easy to lower dialects such as vector and arith to SME without making assumptions about the shape of the input IR.

The goal is to ensure correctness, so we have a base for working on optimizations.

Patches

We already have this implemented on a branch and have tested it on our e2e/integration tests, some interesting cases, and an internal mixed ML/CV model (with IREE).

We have two patches available now:


Thanks Ben!

CC @nujaa @frank_gao @dcaballe @zhaoshiz

This is unlikely to impact folks not using the ArmSME dialect, hence the rather short CC list. For those of you who might be impacted, please let us know if you have any concerns.

-Andrzej
