[Vector] Vector distribution (large vector to small vector)

Reviving this old thread as we (@nicolasvasilache, @matthias-springer and I) have been looking at this again with a fresh start, taking a different approach based on regions (as Mehdi suggested at the time).

Here is a brief explanation of the new mechanism.

Proposed solution

In order to keep the solution incremental, we want the transformations to always go through states that are correct and can be lowered to correct code executable on the GPU.

First step, generating valid SIMT code

We want a first step that never fails and generates code that we can lower and run on a GPU. The most naive solution is to run on a single thread. We introduce a new op called warp_execute_on_lane_0 that runs its region on lane 0 of the SIMT warp, effectively serializing the code and making it semantically correct when running on SIMT hardware.

%i0 = arith.muli %gid, %c32 : index
%l = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
vector.transfer_write %l, %B[%i0] {in_bounds = [true]} : vector<32xf32>, memref<128xf32>

Can be transformed to:

vector.warp_execute_on_lane_0(%laneid : index) {
  %i0 = arith.muli %gid, %c32 : index
  %l = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  vector.transfer_write %l, %B[%i0] {in_bounds = [true]} : vector<32xf32>, memref<128xf32>
}

Without any other transformations this can be lowered to:

%cnd = arith.cmpi eq, %laneid, %c0 : index
scf.if %cnd {
  %i0 = arith.muli %gid, %c32 : index
  %l = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  vector.transfer_write %l, %B[%i0] {in_bounds = [true]} : vector<32xf32>, memref<128xf32>
}

This code is suboptimal but can run correctly.

Region interface

Now that we have a semantically correct version of the code we can apply transformations to avoid serializing the execution on a single lane.

The transfer_write can be distributed onto the threads of the warp with the following transformation:

%i0 = arith.muli %gid, %c32 : index
%v = vector.warp_execute_on_lane_0(%laneid : index) -> (vector<1xf32>) {
  %l = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  vector.yield %l : vector<32xf32>
}
%i1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%laneid)[%i0]
vector.transfer_write %v, %B[%i1] {in_bounds = [true]} : vector<1xf32>, memref<128xf32>

The arithmetic on the group id is uniform and has no side effects, so we can hoist it out of the critical region. We can then hoist the transfer_write op out of the critical region by distributing it across all the threads of the warp; since it has side effects, we need to make sure no thread writes out of bounds.

Since the load is still in the critical region we cannot use its value directly; instead it is yielded by the region, which means it implicitly has to go through memory.

This code can be lowered to:

%tmp = memref.alloc() : memref<32xf32, 3>
%cnd = arith.cmpi eq, %laneid, %c0 : index
%i0 = arith.muli %gid, %c32 : index
scf.if %cnd {
  %l = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  vector.store %l, %tmp[%c0] : memref<32xf32, 3>, vector<32xf32>
}
%d = vector.load %tmp[%laneid] : memref<32xf32, 3>, vector<1xf32>
%i1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%laneid)[%i0]
vector.transfer_write %d, %B[%i1] {in_bounds = [true]} : vector<1xf32>, memref<128xf32>

This means a critical region can be broken down into a chain of critical regions that can each be lowered on their own. Any value at the interface between those regions needs to be spilled to shared memory.
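As a hypothetical illustration (not one of the examples above), such a chain could look like the following, with the distributed value %v connecting the two regions and going through shared memory once each region is lowered on its own:

%v = vector.warp_execute_on_lane_0(%laneid : index) -> (vector<1xf32>) {
  // First critical region: lane 0 reads the full 32-wide vector.
  %l = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  vector.yield %l : vector<32xf32>
}
vector.warp_execute_on_lane_0(%laneid : index, %v : vector<1xf32>) {
^bb0(%arg : vector<32xf32>):
  // Second critical region: lane 0 consumes the value produced by the first one.
  vector.transfer_write %arg, %B[%i0] {in_bounds = [true]} : vector<32xf32>, memref<128xf32>
}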

If instead we move the load out of the critical region we would have:

%i0 = arith.muli %gid, %c32 : index
%i = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%laneid)[%i0]
%l = vector.transfer_read %A[%i], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<1xf32>
vector.warp_execute_on_lane_0(%laneid : index, %l : vector<1xf32>) {
^bb0(%arg : vector<32xf32>):
  vector.transfer_write %arg, %B[%i0] {in_bounds = [true]} : vector<32xf32>, memref<128xf32>
}

This can be lowered to:

%tmp = memref.alloc() : memref<32xf32, 3>
%cnd = arith.cmpi eq, %laneid, %c0 : index
%i0 = arith.muli %gid, %c32 : index
%i = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%laneid)[%i0]
%l = vector.transfer_read %A[%i], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<1xf32>
vector.store %l, %tmp[%laneid] : memref<32xf32, 3>, vector<1xf32>
scf.if %cnd {
  %l1 = vector.load %tmp[%c0] : memref<32xf32, 3>, vector<32xf32>
  vector.transfer_write %l1, %B[%i0] {in_bounds = [true]} : vector<32xf32>, memref<128xf32>
}
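The same mechanism is not limited to transfer ops. As a hypothetical sketch (not part of the examples above, reusing %A and %B purely for illustration), an elementwise op such as arith.addf can be moved out of the region as well, by distributing its operands and its result:

%r:2 = vector.warp_execute_on_lane_0(%laneid : index) -> (vector<1xf32>, vector<1xf32>) {
  %a = vector.transfer_read %A[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  %b = vector.transfer_read %B[%i0], %cst_0 {in_bounds = [true]} : memref<128xf32>, vector<32xf32>
  vector.yield %a, %b : vector<32xf32>, vector<32xf32>
}
// The elementwise op now operates on the 1-element slices held by each lane.
%sum = arith.addf %r#0, %r#1 : vector<1xf32>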

This has been experimented with in the IREE LLVM sandbox in order to handle warp-level reductions starting from vector.reduction:
https://github.com/google/iree-llvm-sandbox/blob/main/test/Integration/Dialect/VectorExt/GPU/test-warp-reduce.mlir
The lit test shows the different incremental transformations to get there:
https://github.com/google/iree-llvm-sandbox/blob/main/test/Dialect/vector_ext/warp.mlir
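To give an idea of the starting point of that experiment, here is a rough sketch (the exact IR is in the tests linked above) of a vector.reduction wrapped in the new op; the incremental transformations then distribute the 32-wide operand across the lanes and combine the per-lane partial results with warp shuffles:

%sum = vector.warp_execute_on_lane_0(%laneid : index) -> (f32) {
  %v = vector.transfer_read %A[%c0], %cst_0 {in_bounds = [true]} : memref<32xf32>, vector<32xf32>
  %red = vector.reduction <add>, %v : vector<32xf32> into f32
  vector.yield %red : f32
}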

I’m going to be working on upstreaming this work and removing the old extract_map/insert_map ops that were previously used for distribution.
I sent the first patch adding the op here:
https://reviews.llvm.org/D123703
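For reference, a minimal sketch of the op as proposed in that patch (see the review for the exact syntax), where the warp size is carried explicitly by the op:

%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %v = "some_def"() : () -> vector<32xf32>
  vector.yield %v : vector<32xf32>
}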