RFC: Representing register data layout explicitly in the IR

Preliminaries

Matrix multiply-accumulate instructions on GPUs have constraints on which threads can hold which fragments of the A, B, C and D matrices in their registers. For example, AMD GPUs require the data to be arranged in a specific per-lane format (see AMD matrix cores (amd-lab-notes) - AMD GPUOpen). Other hardware vendors have similar constraints. In the following we will refer to these constraints as a ‘layout’, but the concept is actually quite low level and is more akin to the configuration of certain register blocks needed to satisfy ISA constraints (as opposed to ‘layout’ as more commonly used higher in the stack).

Layouts will not stem only from these MMA instructions. Restrictions on certain other operators can also be expressed as a layout. As an example, if we wanted to load 128 bits of data at a time from global to shared memory, we could impose this constraint as a layout on the vector.transfer_read operator.
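A minimal sketch of what that could look like, assuming a hypothetical #layout_128b attribute carried as a discardable attribute on the op (the attribute name and its placement are purely illustrative):

%v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true], layout = #layout_128b} : memref<128x64xf16>, vector<128x64xf16>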

Problem Statement

We would like to represent this layout explicitly in the IR so we can reason about the layouts of the operands and results of such operators, as well as of any neighboring operators that could reuse the data already present in registers. For example, consider the IR shown below for flash attention.

%9 = vector.transfer_read %alloc[%c0, %8, %c0], %cst_4 {in_bounds = [true, true]} : memref<1x128x64xf16, #gpu.address_space<workgroup>>, vector<32x64xf16>
%10 = arith.extf %9 : vector<32x64xf16> to vector<32x64xf32>
%11:3 = scf.for %arg0 = %c0 to %c4096 step %c128 iter_args(%arg1 = %cst_0,   %arg2 = %cst_1, %arg3 = %cst) -> (vector<32xf32>, vector<32xf32>, vector<32x64xf32>) {
  %13 = vector.transfer_read %alloc_11[%c0, %c0], %cst_4 {in_bounds = [true, true]} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<128x64xf16>
  %14 = arith.extf %13 : vector<128x64xf16> to vector<128x64xf32>
  %15 = vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %14, %cst_2 : vector<32x64xf32>, vector<128x64xf32> into vector<32x128xf32>
  %16 = vector.multi_reduction <maxf>, %15, %arg1 [1] : vector<32x128xf32> to vector<32xf32>
  %17 = vector.broadcast %16 : vector<32xf32> to vector<128x32xf32>
  %18 = vector.transpose %17, [1, 0] : vector<128x32xf32> to vector<32x128xf32>
  %19 = arith.subf %15, %18 : vector<32x128xf32>
  %20 = math.exp2 %19 : vector<32x128xf32>
  %21 = arith.subf %arg1, %16 : vector<32xf32>
  %22 = math.exp2 %21 : vector<32xf32>
  %23 = arith.mulf %22, %arg2 : vector<32xf32>
  %24 = vector.multi_reduction <add>, %20, %23 [1] : vector<32x128xf32> to vector<32xf32>
  %25 = arith.divf %cst_3, %24 : vector<32xf32>
  %26 = vector.broadcast %25 : vector<32xf32> to vector<128x32xf32>
  %27 = vector.transpose %26, [1, 0] : vector<128x32xf32> to vector<32x128xf32>
  %28 = arith.mulf %20, %27 : vector<32x128xf32>
  %29 = arith.truncf %28 : vector<32x128xf32> to vector<32x128xf16>
  %30 = vector.broadcast %23 : vector<32xf32> to vector<64x32xf32>
  %31 = vector.broadcast %25 : vector<32xf32> to vector<64x32xf32>
  %32 = vector.transpose %30, [1, 0] : vector<64x32xf32> to vector<32x64xf32>
  %33 = vector.transpose %31, [1, 0] : vector<64x32xf32> to vector<32x64xf32>
  %34 = arith.mulf %32, %33 : vector<32x64xf32>
  %35 = arith.mulf %34, %arg3 : vector<32x64xf32>
  %36 = arith.extf %29 : vector<32x128xf16> to vector<32x128xf32>
  %37 = vector.transfer_read %alloc_12[%c0, %c0], %cst_4 {in_bounds = [true, true], permutation_map = #map7} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<64x128xf16>
  %38 = arith.extf %37 : vector<64x128xf16> to vector<64x128xf32>
  %39 = vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %36, %38, %35 : vector<32x128xf32>, vector<64x128xf32> into vector<32x64xf32>
  scf.yield %16, %24, %39 : vector<32xf32>, vector<32xf32>, vector<32x64xf32>
}
%12 = arith.truncf %11#2 : vector<32x64xf32> to vector<32x64xf16>
vector.transfer_write %12, %alloc_7[%c0, %8, %c0] {in_bounds = [true, true]} : vector<32x64xf16>, memref<1x128x64xf16, #gpu.address_space<workgroup>>

Some of the questions we would like to answer are

  • Is it possible to execute all the operations above using only thread-local registers or do we need to copy data to shared memory?
  • When doing the reductions, how many lanes are participating in the reduction? How many lanes will be participating in the following operations after the reduction?

Answering these questions requires knowledge of the layouts imposed by the MMA instructions, since we are chaining two MMA operations: for example, if we attempt to use the D matrix of the first MMA as the B matrix of the second MMA, then we need to reconcile the difference between the D and B layouts imposed by the MMA instruction.
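To make the conflict concrete, here is a hypothetical sketch using the shapes from the IR above (the mydialect.to_layout op, the value names and the layout semantics are illustrative, not existing IR): the first contract produces its result in the accumulator (D) layout, while the second contract wants to consume it in an operand (B) layout, so a conversion, ultimately cross-lane shuffles or a round trip through shared memory, has to sit in between.

%d  = vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %q, %k, %acc0 : vector<32x64xf32>, vector<128x64xf32> into vector<32x128xf32>   // result carries the D layout of the first MMA
%d2 = mydialect.to_layout %d : vector<32x128xf32>   // hypothetical reconciliation into the B layout expected by the second MMA
%o  = vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %d2, %v, %acc1 : vector<32x128xf32>, vector<64x128xf32> into vector<32x64xf32>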

Note on lowering to LLVM

The layout information is used to perform vector distribution. As an example, the vector.transfer_read

%3 = vector.transfer_read %0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>

can be lowered to a memref.load or a vector.load (depending on how many elements each thread loads, which is captured by the layout). The memref.load case is shown below; a sketch of the vector.load case follows it.

#map = affine_map<(d0, d1, d2) -> (d1 + d2 * 16)>
#map1 = affine_map<(d0, d1, d2) -> (d0 * 2)>
%3 = gpu.thread_id x
%4 = gpu.thread_id y
%5 = affine.apply #map(%3, %4, %c0)
%6 = affine.apply #map1(%3, %4, %c0)
%9 = memref.load %0[%5, %6] : memref<16x16xf16>
%10 = vector.broadcast %9 : f16 to vector<1xf16>
%11 = vector.insert_strided_slice %10, %cst_1 {offsets = [0, 0, 0, 0], strides = [1]} : vector<1xf16> into vector<1x1x4x2xf16>
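If the layout instead prescribed wider per-thread accesses (say 8 contiguous f16 elements, i.e. 128 bits), the same read would distribute to a vector.load. A rough sketch, with the index computation elided and the per-thread register shape purely illustrative:

%load = vector.load %0[%5, %6] : memref<16x16xf16>, vector<8xf16>
%res = vector.insert_strided_slice %load, %init {offsets = [0, 0, 0, 0], strides = [1]} : vector<8xf16> into vector<1x1x1x8xf16>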

Similarly, a vector.contract is lowered to the appropriate MMA operation.

%5 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %cst : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>

Turns into

%55 = vector.extract %cst[0, 0] : vector<1x1x2x2xf16>
%56 = vector.extract %40[0, 0] : vector<1x1x4x2xf16>
%57 = vector.extract %54[0, 0] : vector<1x1x2x2xf16>
%58 = nvgpu.mma.sync(%56, %57, %55) {mmaShape = [16, 8, 16]} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
-> vector<2x2xf16>
%59 = vector.insert %58, %cst [0, 0] : vector<2x2xf16> into vector<1x1x2x2xf16>

After vector distribution, the layout information is no longer required for any further downstream analysis.

Proposed Solutions

Option 1: Adding an optional “layout” or register description attribute to the VectorType

The idea here would be to add a layout attribute to the vector type, resulting in IR like that shown below.

%lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>} : () -> (vector<16x16xf16, #layout>)

Here #layout can be defined by the user. Some examples of what such a layout might look like can be found in CuTe/CUTLASS, Triton and https://github.com/nod-ai/techtalks/blob/main/high_dimensional_layout_flash_attn_harsh_23.pdf.
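As a purely illustrative sketch (the dialect, attribute name and fields below are hypothetical), such a layout could describe, per dimension of the vector, how elements are split across lanes and across the elements each lane keeps in registers; the numbers here would correspond to a 16x16 accumulator fragment spread over a 64-lane wave with 4 elements per lane:

#layout = #mydialect.layout<
  dim0 = [lanes = 16],               // rows spread across 16 lanes
  dim1 = [lanes = 4, elements = 4]   // columns spread across 4 lanes, 4 register elements each
>
%lhs = arith.constant dense<0.0> : vector<16x16xf16, #layout>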

The advantages of this approach are

  • The layout information is contained in the type so it is easy to reason about by just querying the layout
  • Since it is only required during layout reasoning, it does not need to persist beyond vector distribution (minimizing the impact on other users)
  • Works with all existing operators in the arith, vector, etc. dialects

The downsides of this approach are

  • Changing the vector type will require changes to all of its clients, which will need to understand the semantics of the layout attribute; at the very least, they will need to reject it.

The attribute here does not need to be restricted to being a layout attribute; it could be a general optional attribute. The only guarantee would be that upstream ops handle such attributes correctly (not dropping them during transformations, etc.). This allows downstream users to do what they want with it, and upstream does not even need to know about it. Later on, the vector distribution/layout pass would simply hook into this.

Option 2: Using the tensor type for higher-level pre-distribution phasing

My understanding of this option is that, since the layout is an abstracted value-semantic concept, the tensor type might be more suitable for this abstraction. From the docs on the difference between tensor and vector types:

Conceptual: vectors are meant to and occur in lower level dialects - often where you expect hardware to have registers of that size. Tensors model higher-level “closer to the source” abstract representation.

This case would seem to be closer to the vector use case. But if we were to go down this path, we would use the tensor type and its encoding attribute to encode the layout. The current flow in IREE goes from Linalg on tensors to vectorization, followed by bufferization and finally distribution. So we would annotate the types of linalg.matmul operations with the layout information and propagate it to the results and operands of neighboring operations. The problem with this approach is that once we vectorize, we lose the layout information in the tensors as we move to vector types, which implies that we would not be able to do distribution on vectors, which is where it is beneficial.
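A sketch of what the pre-vectorization annotation could look like under this option, assuming hypothetical #layout_a/#layout_b/#layout_c encoding attributes on the tensor types:

%0 = linalg.matmul ins(%a, %b : tensor<32x128xf16, #layout_a>, tensor<128x64xf16, #layout_b>)
                   outs(%c : tensor<32x64xf32, #layout_c>) -> tensor<32x64xf32, #layout_c>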

The advantages of this approach are

  • Can reuse TensorType, no need to change VectorType

The disadvantages of this approach are

  • Cannot do distribution on vectors

Option 3: Using a custom mydialect.my_vector type

My understanding of this option is that we create our own custom type, such as a FragmentedVectorType. Over time, patterns and transformations can be ported to VectorTypeOrFragmentedVectorType (i.e., to accept either type).
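A sketch of what IR using such a type could look like (the dialect, ops and type below are all hypothetical):

%0 = mydialect.transfer_read %alloc[%c0, %c0], %pad : memref<128x64xf16>, !mydialect.fragmented_vector<128x64xf16, #layout>
%1 = mydialect.extf %0 : !mydialect.fragmented_vector<128x64xf16, #layout> to !mydialect.fragmented_vector<128x64xf32, #layout>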

The advantages of this approach are

  • Separate type, no need to modify clients of existing VectorType

The downsides of this approach are

  • Cannot use existing vector ops with this type; new ops will need to be added and/or existing ops incrementally ported to the new type

It’s a little bit odd to see a placeholder for an RFC :slight_smile:

I assume that this is for continuing the discussion regarding settling the approach for staging distribution of vectors for GPU. Forked from:

I believe that there were three approaches under discussion:

  • Adding an optional “layout” or “register description” attribute to VectorType
  • Using tensor for higher level pre-distribution phasing
  • Introducing a my_dialect.my_vector type of some kind and making it interop with the existing body of code

It seems like even if this RFC initially advocates for the first, the discussion will encompass all three.

Yes that is correct. I have updated the RFC with more information.

I see this as a deal breaker right now: having this constraint seems unmaintainable in a system where it is “optional” to have a layout (this is why the MLIR API forced a Location in build methods…).

I don’t see this as sustainable, because most code today will just “continue to work” but is now de facto incorrect, with latent bugs that will be hard to find.

Maybe we should consider a different type: VectorWithLayoutType. We could consider having the arith dialect support it, but it would be less likely for code to “just work” by accident, I think?

I tried thinking about whether we can get away with a discardable attribute, but I haven’t converged on this yet; some raw notes below, feel free to ignore :slight_smile:

Trying to re-describe the problem in my own words (avoiding SIMT/GPU) and to go through what it’d look like:

The vector type describes an Nd dense iteration space, which at the moment is tightly coupled to the physical storage of the vector.
The computations we describe on vectors implicitly use this mapping by assuming that SIMD lanes process the elements of this iteration space consecutively.
This starts to be an issue when different HW has different requirements for how SIMD lanes map to the computation; a layout could help decouple the “logical organization of the vector” from the SIMD lane mapping.
Conceptually I view this as “some HW operations actually have a builtin shuffle”.

To keep it simple, let’s imagine that when implementing the op %exp = math.exp2 %input : vector<8xf32>, the HW op myhw.exp2 can do this, but actually produces a result where the first two elements of the vector are swapped. Lowering math.exp2 to my HW then requires reshuffling the output (or, in this case, the input):

%temp = myhw.exp2 %input : vector<8xf32>
%shuffle = vector.shuffle %temp, %temp [1, 0, 2, 3, 4, 5, 6, 7] : vector<8xf32>, vector<8xf32>

Of course, all of these shuffles are undesirable and the goal is to eliminate them. We could do it after materializing them, at the HW representation level, but this can quickly get impractical. Ideally we’d want to start performing these optimizations before doing “instruction selection” here.

If we add a layout, it models a contract between the producer and consumer of these vectors about the “implicit” shuffle happening. It allows us to start pushing these “implicit shuffles” from the HW operations into the HW-agnostic vector operations (like the ones in the arith dialect).

In terms of transformations, this is where it gets gnarly: at a given level of abstraction, transformations could drop the attribute; after all, it’s just instructing the lowering about where shuffles can be eliminated. For example:

%temp = math.exp2 %input : vector<8xf32> -> vector<8xf32, #shuffle<[1, 0, 2, 3, 4, 5, 6, 7]>>

This op in isolation can be lowered to myhw.exp2 without inserting a shuffle. Dropping the shuffle layout attribute is just fine: the legalization will insert the explicit shuffle in the IR.

How do we represent the IR post-lowering?
It’s likely that we need to preserve this information; it does not seem correct to write

%temp = myhw.exp2 %input : vector<8xf32> -> vector<8xf32>

Instead we must write:

%temp = myhw.exp2 %input : vector<8xf32> -> vector<8xf32, #shuffle<[1, 0, 2, 3, 4, 5, 6, 7]>>

So can we drop a layout? Maybe, but it requires reconciling the layouts (possibly with shuffles) after the fact. If the producer of %input is an arith op, what if it just drops the layout during one of the transformations there?
I tried to think about modeling this with an interface, where we could have two categories of operations, “layout aware” and “layout unaware”, where the former needs a “reconciliation” of the input/output layouts. But I’m not sure how this all fits together right now.

I had the same mental model fwiw, and have been trying to reconcile it because the shuffle is implicit and not strictly required at the hardware-instruction level.

However, I then talk myself out of that: let’s say that we had started with this instruction set vs. growing into it. At some level, we would have put the shuffle on the type, and carrying it lower towards the hardware instructions wouldn’t hurt (and may even produce some verification benefits).

I was also heading down this path. It does seem like there is an interplay between the shuffle-carrying types and the shuffle-canceling operations (and those that are unshuffled).

Anyway, to answer my question above: if we’d started with this rather than growing into it, I think we would have modeled it in order to avoid too big of a jump down to hardware instructions. And we would have been forced to reconcile the fact that most operations need a way to productively ignore/propagate it, while a small handful act on it (by verifying and then dropping it during lowering).

The way to do this ergonomically with MLIR’s extant type system is eluding me a bit.
