Preliminaries
Matrix multiply-accumulate (MMA) instructions on GPUs constrain which threads can hold which fragments of the A, B, C and D matrices in their registers. For example, AMD GPUs require the data to be arranged as described in AMD matrix cores (amd-lab-notes) - AMD GPUOpen, and other hardware vendors have similar constraints. In the following we will refer to these constraints as a ‘layout’, but the concept is actually quite low level and is closer to the configuration of certain register blocks needed to satisfy ISA constraints (as opposed to ‘layout’ as the term is more commonly used higher in the stack).
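To make this concrete, the accumulator (C/D) fragment of an MFMA variant can be thought of as a fixed map from (lane, register element) to the matrix position that lane holds. The sketch below is my reading of the amd-lab-notes post for the 16x16x16 variant on a 64-lane wavefront; the exact index formula is an assumption and should be checked against the ISA documentation.
// Assumed C/D fragment layout of a 16x16x16 MFMA: lane l holds 4 accumulator
// elements, with element i sitting at row 4 * (l / 16) + i and column l % 16.
#mfma_cd_layout = affine_map<(lane, i) -> ((lane floordiv 16) * 4 + i, lane mod 16)>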
The layout does not stem only from MMA instructions. Restrictions we impose on other operators can also be represented as layouts. For example, if we wanted each thread to load 128 bits of data when copying a tile from global to shared memory, we could impose that constraint as a layout on the vector.transfer_read operator.
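For instance (the attribute and its fields below are hypothetical, purely to illustrate where such a constraint would attach), the 128-bit-per-thread requirement could be expressed as a layout directly on the read:
// Hypothetical layout attribute stating that each thread reads one contiguous
// vector<8xf16> (128 bits); the distribution pass would honour this when
// rewriting the transfer_read into per-thread loads.
%tile = vector.transfer_read %src[%c0, %c0], %pad
    {in_bounds = [true, true], mydialect.layout = #mydialect.load_layout<elements_per_thread = 8>}
    : memref<128x64xf16>, vector<128x64xf16>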
Problem Statement
We would like to represent this layout explicitly in the IR so that we can reason about the layouts of the operands and results of such operators, as well as of any neighboring operators that could reuse the data already present in registers. For example, consider the IR shown below for flash attention.
%9 = vector.transfer_read %alloc[%c0, %8, %c0], %cst_4 {in_bounds = [true, true]} : memref<1x128x64xf16, #gpu.address_space<workgroup>>, vector<32x64xf16>
%10 = arith.extf %9 : vector<32x64xf16> to vector<32x64xf32>
%11:3 = scf.for %arg0 = %c0 to %c4096 step %c128 iter_args(%arg1 = %cst_0, %arg2 = %cst_1, %arg3 = %cst) -> (vector<32xf32>, vector<32xf32>, vector<32x64xf32>) {
%13 = vector.transfer_read %alloc_11[%c0, %c0], %cst_4 {in_bounds = [true, true]} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<128x64xf16>
%14 = arith.extf %13 : vector<128x64xf16> to vector<128x64xf32>
%15 = vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %14, %cst_2 : vector<32x64xf32>, vector<128x64xf32> into vector<32x128xf32>
%16 = vector.multi_reduction <maxf>, %15, %arg1 [1] : vector<32x128xf32> to vector<32xf32>
%17 = vector.broadcast %16 : vector<32xf32> to vector<128x32xf32>
%18 = vector.transpose %17, [1, 0] : vector<128x32xf32> to vector<32x128xf32>
%19 = arith.subf %15, %18 : vector<32x128xf32>
%20 = math.exp2 %19 : vector<32x128xf32>
%21 = arith.subf %arg1, %16 : vector<32xf32>
%22 = math.exp2 %21 : vector<32xf32>
%23 = arith.mulf %22, %arg2 : vector<32xf32>
%24 = vector.multi_reduction <add>, %20, %23 [1] : vector<32x128xf32> to vector<32xf32>
%25 = arith.divf %cst_3, %24 : vector<32xf32>
%26 = vector.broadcast %25 : vector<32xf32> to vector<128x32xf32>
%27 = vector.transpose %26, [1, 0] : vector<128x32xf32> to vector<32x128xf32>
%28 = arith.mulf %20, %27 : vector<32x128xf32>
%29 = arith.truncf %28 : vector<32x128xf32> to vector<32x128xf16>
%30 = vector.broadcast %23 : vector<32xf32> to vector<64x32xf32>
%31 = vector.broadcast %25 : vector<32xf32> to vector<64x32xf32>
%32 = vector.transpose %30, [1, 0] : vector<64x32xf32> to vector<32x64xf32>
%33 = vector.transpose %31, [1, 0] : vector<64x32xf32> to vector<32x64xf32>
%34 = arith.mulf %32, %33 : vector<32x64xf32>
%35 = arith.mulf %34, %arg3 : vector<32x64xf32>
%36 = arith.extf %29 : vector<32x128xf16> to vector<32x128xf32>
%37 = vector.transfer_read %alloc_12[%c0, %c0], %cst_4 {in_bounds = [true, true], permutation_map = #map7} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<64x128xf16>
%38 = arith.extf %37 : vector<64x128xf16> to vector<64x128xf32>
%39 = vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %36, %38, %35 : vector<32x128xf32>, vector<64x128xf32> into vector<32x64xf32>
scf.yield %16, %24, %39 : vector<32xf32>, vector<32xf32>, vector<32x64xf32>
}
%12 = arith.truncf %11#2 : vector<32x64xf32> to vector<32x64xf16>
vector.transfer_write %12, %alloc_7[%c0, %8, %c0] {in_bounds = [true, true]} : vector<32x64xf16>, memref<1x128x64xf16, #gpu.address_space<workgroup>>
Some of the questions we would like to answer are:
- Is it possible to execute all the operations above using only thread-local registers, or do we need to copy data to shared memory?
- When doing the reductions, how many lanes participate in the reduction, and how many lanes participate in the operations that follow it?
Answering these questions requires knowledge of the layouts imposed by the MMA instructions, since we are chaining two MMA operations: if we attempt to use the D matrix of the first MMA as the B matrix of the second MMA, we need to reconcile the difference between the D and B layouts imposed by the instruction.
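For example (the op and attribute names below are hypothetical, purely to show the kind of IR the analysis should reason about), if the tile produced in the first MMA's D layout feeds the second contraction, we either have to prove that the D and B layouts coincide or make the conversion explicit so its cost can be reasoned about:
// %p carries the D (accumulator) layout of the first MMA. The second MMA
// expects its operand in a different layout, so a hypothetical conversion op
// makes the mismatch explicit; the analysis then decides whether it is a
// no-op, an in-register shuffle, or a round trip through shared memory.
%p_as_b = mydialect.to_layout %p {from = #mma_d_layout, to = #mma_b_layout} : vector<32x128xf16>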
Note on lowering to LLVM
The layout information is used to perform the vector distribution. As an example, the vector.transfer_read
%3 = vector.transfer_read %0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
can be lowered to a memref.load or a vector.load, as shown below (which of the two is used depends on how many elements each thread loads, and that is captured by the layout).
#map = affine_map<(d0, d1, d2) -> (d1 + d2 * 16)>
#map1 = affine_map<(d0, d1, d2) -> (d0 * 2)>
%3 = gpu.thread_id x
%4 = gpu.thread_id y
%5 = affine.apply #map(%3, %4, %c0)
%6 = affine.apply #map1(%3, %4, %c0)
%9 = memref.load %0[%5, %6] : memref<16x16xf16>
%10 = vector.broadcast %9 : f16 to vector<1xf16>
%11 = vector.insert_strided_slice %10, %cst_1 {offsets = [0, 0, 0, 0], strides = [1]} : vector<1xf16> into vector<1x1x4x2xf16>
Similarly, a vector.contract is lowered to the appropriate MMA operation.
%5 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %cst : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
Turns into
%55 = vector.extract %cst[0, 0] : vector<1x1x2x2xf16>
%56 = vector.extract %40[0, 0] : vector<1x1x4x2xf16>
%57 = vector.extract %54[0, 0] : vector<1x1x2x2xf16>
%58 = nvgpu.mma.sync(%56, %57, %55) {mmaShape = [16, 8, 16]} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
-> vector<2x2xf16>
%59 = vector.insert %58, %cst [0, 0] : vector<2x2xf16> into vector<1x1x2x2xf16>
After vector distribution, the layout information is no longer required for any further downstream analysis.
Proposed Solutions
Option 1: Adding an optional “layout” or register description attribute to the VectorType.
The idea here would be to add a layout attribute to the vector type, resulting in IR like the following.
%lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>} : () -> (vector<16x16xf16, #layout>)
where #layout can be defined by the user. Some examples of what such a layout might look like can be found in CuTe/CUTLASS, Triton, and https://github.com/nod-ai/techtalks/blob/main/high_dimensional_layout_flash_attn_harsh_23.pdf.
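As a strawman (the attribute name and fields are invented here, not a concrete syntax proposal), such an attribute could describe, per dimension of the 16x16 vector above, how the data is split across lanes and per-lane register elements:
// Hypothetical layout: dim 0 (16 rows) is spread across 16 lanes; dim 1
// (16 columns) is split across 4 groups of lanes, each lane holding 4
// contiguous elements in registers (64 lanes x 4 elements = 256 = 16x16).
#layout = #mydialect.layout<dim0 = [lanes<16>], dim1 = [lanes<4>, register<4>]>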
The advantages of this approach are
- The layout information is contained in the type, so it is easy to reason about by simply querying the type
- Since it is only required while reasoning about layouts, it does not need to persist beyond vector distribution (minimizing the impact on other users)
- Works with all existing operators in the arith, vector, etc. dialects
The downsides of this approach are
- Changing the vector type will require changes to all its clients, which will need to understand the semantics of the layout attribute; at the very least, they will need to reject vectors carrying a layout they do not handle.
The attribute here does not need to be restricted to a layout attribute; it could be a general optional attribute. The only guarantee would be that upstream ops handle such attributes correctly (e.g. not drop them during transformations). This would allow downstream users to do whatever they want with the attribute, without upstream even needing to know about it, and the vector distribution/layout pass would later just hook into it.
Option 2: Using the TensorType for the higher-level, pre-distribution phase
My understanding of this option is that, since the layout is an abstract, value-semantic concept, the tensor type might be more suitable for this abstraction. From the docs on the difference between tensor and vector types:
Conceptual: vectors are meant to and occur in lower level dialects - often where you expect hardware to have registers of that size. Tensors model higher-level “closer to the source” abstract representation.
This case would seem to be closer to the vector use case. But if we were to go down this path, we would use the tensor type and its encoding attribute to encode the layout. The current flow in IREE goes from Linalg on tensors to vectorization, followed by bufferization and finally distribution. We would therefore annotate the types of linalg.matmul operations with the layout information and propagate it to the results and operands of neighboring operations. The problem with this approach is that vectorization drops the encoding when tensor types are rewritten into vector types, so the layout information is lost and we would no longer be able to do distribution at the vector level, which is where it is most beneficial.
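Concretely (with hypothetical #layout_* encoding attributes), the annotation would look like the snippet below and would have to be attached and propagated before vectorization:
// Hypothetical: the MMA-imposed layouts are carried in the tensor encodings of
// the linalg.matmul operands and result and propagated to neighbouring ops.
// Vectorization then rewrites these into plain vector types, at which point
// the encodings, and with them the layout information, are lost.
%d = linalg.matmul ins(%a, %b : tensor<32x128xf16, #layout_a>, tensor<128x64xf16, #layout_b>)
                   outs(%c : tensor<32x64xf32, #layout_cd>) -> tensor<32x64xf32, #layout_cd>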
The advantages of this approach are
- Can reuse TensorType, no need to change VectorType
The disadvantages of this approach are
- Cannot do distribution on vectors
Option 3: Using a custom mydialect.my_vector type
My understanding of this option is that we would create our own custom type, such as a FragmentedVectorType. Over time, patterns and transformations could be ported to work on VectorTypeOrFragmentedVectorType.
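In the IR this might look like the snippet below (all names hypothetical); note that even simple producers and consumers need counterparts that accept the new type until existing ops are ported:
// Hypothetical !mydialect.my_vector type carrying the fragment layout.
// Existing vector/arith ops will not verify against it, so new ops (or
// ported upstream ops) are needed even for a simple read plus extension.
%q = mydialect.transfer_read %alloc[%c0, %c0]
    : memref<128x64xf16>, !mydialect.my_vector<128x64xf16, #mma_a_layout>
%q_f32 = mydialect.extf %q : !mydialect.my_vector<128x64xf16, #mma_a_layout> to !mydialect.my_vector<128x64xf32, #mma_a_layout>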
The advantages of this approach are
- Separate type, no need to modify clients of existing VectorType
The downsides of this approach are
- Cannot use existing vector ops with this type; new ops will need to be added and/or existing ops incrementally ported to the new type