This RFC aims to improve gather codegen support in MLIR for the commonly taken linalg -> vector -> llvm/spirv codegen path. We first survey today's gather codegen support, compare the code it generates with what "good" code should look like, and propose a solution to bridge the gap between the two.
If you are unfamiliar with what a gather is, PyTorch docs do a good job describing a gather: https://pytorch.org/docs/main/generated/torch.gather.html
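In plain Python, a gather amounts to indexing one array with another (toy data, chosen here for illustration):

```python
# A gather selects elements of `storage` according to an index array.
storage = [[10, 11], [20, 21], [30, 31]]
ind = [2, 0]
gathered = [storage[i] for i in ind]  # rows 2 and 0, in that order
# gathered == [[30, 31], [10, 11]]
```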
Motivating Example
We start with a simple example which we will call a “Paged Read”, where the IR gathers from a 2-D storage in a non-contiguous fashion on the outermost dimension. This example can be written in PyTorch as:
import torch

def paged_read(storage: torch.Tensor, ind: torch.LongTensor):
return storage[ind, :]
# Example Input
storage = torch.randn([8192, 8], dtype=torch.float16)
ind = torch.LongTensor([24, 43, 36, 57])
read = paged_read(storage, ind)
In Linalg, this example can be written as (and is exported through Torch-MLIR in the same way):
!storage = tensor<8192x8xf16>
!ind = tensor<4xindex>
!x = tensor<4x8xf16>
#gather = {
indexing_maps = [affine_map<(page, vec) -> (page)>,
affine_map<(page, vec) -> (page, vec)>],
iterator_types = ["parallel", "parallel"]
}
func.func @main(%storage : !storage, %ind: !ind) -> !x {
%x = tensor.empty() : !x
%x_g = linalg.generic #gather
ins(%ind : !ind)
outs(%x : !x) {
^bb0(%page: index, %out: f16):
%vec = linalg.index 1 : index
%extracted = tensor.extract %storage[%page, %vec] : !storage
linalg.yield %extracted : f16
} -> !x
return %x_g : !x
}
Today, vectorizing this IR with the Linalg vectorizer and unrolling to 1-D vectors (to target LLVM) produces the following code. The IR is shown post-bufferization because, in MLIR today, unrolling is only implemented on buffers and not on tensors, as unrolling tensors can interfere with bufferization analysis:
func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
%0 = ub.poison : vector<4x8xindex>
%1 = ub.poison : vector<8x4xindex>
%cst = arith.constant dense<8> : vector<8x4xindex>
%cst_0 = arith.constant dense<0.000000e+00> : vector<4x8xf16>
%cst_1 = arith.constant dense<true> : vector<4x8xi1>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x8xf16>
%2 = vector.load %arg1[0] : memref<4xindex>, vector<4xindex>
%3 = vector.insert %2, %1 [0] : vector<4xindex> into vector<8x4xindex>
%4 = vector.insert %2, %3 [1] : vector<4xindex> into vector<8x4xindex>
%5 = vector.insert %2, %4 [2] : vector<4xindex> into vector<8x4xindex>
%6 = vector.insert %2, %5 [3] : vector<4xindex> into vector<8x4xindex>
%7 = vector.insert %2, %6 [4] : vector<4xindex> into vector<8x4xindex>
%8 = vector.insert %2, %7 [5] : vector<4xindex> into vector<8x4xindex>
%9 = vector.insert %2, %8 [6] : vector<4xindex> into vector<8x4xindex>
%10 = vector.insert %2, %9 [7] : vector<4xindex> into vector<8x4xindex>
%11 = vector.step : vector<8xindex>
%12 = arith.muli %10, %cst : vector<8x4xindex>
%13 = vector.transpose %12, [1, 0] : vector<8x4xindex> to vector<4x8xindex>
%14 = vector.insert %11, %0 [0] : vector<8xindex> into vector<4x8xindex>
%15 = vector.insert %11, %14 [1] : vector<8xindex> into vector<4x8xindex>
%16 = vector.insert %11, %15 [2] : vector<8xindex> into vector<4x8xindex>
%17 = vector.insert %11, %16 [3] : vector<8xindex> into vector<4x8xindex>
%18 = arith.addi %17, %13 : vector<4x8xindex>
%19 = vector.gather %arg0[0, 0] [%18], %cst_1, %cst_0 : memref<8192x8xf16>, vector<4x8xindex>, vector<4x8xi1>, vector<4x8xf16> into vector<4x8xf16>
%20 = vector.extract %19[0] : vector<8xf16> from vector<4x8xf16>
vector.store %20, %alloc[0, 0] : memref<4x8xf16>, vector<8xf16>
%21 = vector.extract %19[1] : vector<8xf16> from vector<4x8xf16>
vector.store %21, %alloc[1, 0] : memref<4x8xf16>, vector<8xf16>
%22 = vector.extract %19[2] : vector<8xf16> from vector<4x8xf16>
vector.store %22, %alloc[2, 0] : memref<4x8xf16>, vector<8xf16>
%23 = vector.extract %19[3] : vector<8xf16> from vector<4x8xf16>
vector.store %23, %alloc[3, 0] : memref<4x8xf16>, vector<8xf16>
return %alloc : memref<4x8xf16>
}
Notice the vector.gather operation produced. When lowered, this IR produces llvm.masked.gather intrinsics or scalar loads, depending on the lowering path taken.
It is generally better to generate contiguous loads whenever possible. A better lowering would unroll along the outer, gathered dimension and perform contiguous loads on the inner dimension. The expected "good" IR looks like:
func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x8xf16>
%0 = vector.load %arg1[0] : memref<4xindex>, vector<4xindex>
%1 = vector.extract %0[0] : index from vector<4xindex>
%3 = vector.extract %0[1] : index from vector<4xindex>
%5 = vector.extract %0[2] : index from vector<4xindex>
%7 = vector.extract %0[3] : index from vector<4xindex>
%2 = vector.load %arg0[%1, 0] : memref<8192x8xf16>, vector<8xf16>
%4 = vector.load %arg0[%3, 0] : memref<8192x8xf16>, vector<8xf16>
%6 = vector.load %arg0[%5, 0] : memref<8192x8xf16>, vector<8xf16>
%8 = vector.load %arg0[%7, 0] : memref<8192x8xf16>, vector<8xf16>
vector.store %2, %alloc[0, 0] : memref<4x8xf16>, vector<8xf16>
vector.store %4, %alloc[1, 0] : memref<4x8xf16>, vector<8xf16>
vector.store %6, %alloc[2, 0] : memref<4x8xf16>, vector<8xf16>
vector.store %8, %alloc[3, 0] : memref<4x8xf16>, vector<8xf16>
return %alloc : memref<4x8xf16>
}
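The behavior of this "good" IR can be sketched in numpy (data and index values here are illustrative): one contiguous row load per gathered index, rather than a per-element gather:

```python
import numpy as np

# Sketch of the "good" lowering: one contiguous 8-element row load per
# gathered index instead of a per-element masked gather.
storage = np.arange(8192 * 8, dtype=np.float32).reshape(8192, 8)
ind = np.array([24, 43, 36, 57])  # 4 indices, matching tensor<4xindex>

out = np.empty((4, 8), dtype=np.float32)
for row, page in enumerate(ind):
    out[row] = storage[page]  # contiguous load of one full row
```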
Missing Abstraction Gap
To understand why we cannot generate the "good" IR today, we need a closer look at vector.gather, which is the vectorization target for gathers from Linalg today. From the docs:
The gather operation returns an n-D vector whose elements are either loaded from memory or ranked tensor, or taken from a pass-through vector, depending on the values of an n-D mask vector. If a mask bit is set, the corresponding result element is defined by the base with indices and the n-D index vector (each index is a 1-D offset on the base). Otherwise, the corresponding element is taken from the n-D pass-through vector.
Notice that each entry in the index vector is a single 1-D offset into N-D memory. The gather operation effectively treats the N-D memory as flat 1-D memory when accessing it. By doing this, we lose any per-dimension information we could exploit: in the paged read example, the fact that one dimension is contiguous while the other is not is lost, because the dimension information was linearized.
Before vectorization, we know that a dimension being accessed is contiguous (the dimension is indexed by a linalg.index), but vectorization throws this information away. This points to an abstraction gap, caused by a missing vector dialect operation that preserves the structure of the indexing.
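Concretely, the linearization performed on the vector.gather path amounts to the following index arithmetic (a Python sketch with illustrative index values):

```python
# vector.gather sees only flat 1-D offsets: offset[i][j] = ind[i] * 8 + j,
# where 8 is the row stride of the 8192x8 storage.
ind = [24, 43, 36, 57]
offsets = [[page * 8 + j for j in range(8)] for page in ind]
# Each row of `offsets` is 8 consecutive addresses (a contiguous load),
# but after linearization nothing in the IR records that fact.
```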
Preserving Structure on Indices and Memory Accesses
To build a better-structured abstraction for gathers, we look at the vectorization target for contiguous loads, vector.transfer_read, and try to learn from it. vector.transfer_read preserves the structure of how memory is accessed for contiguous loads. The input slice structure is simple: it is a contiguous slice, and the operation preserves how a vector is extracted from it.
For gathers, the input slice is not contiguous; instead it is constructed from the indices mapping input_slice[i] = input[indices[i]]. To preserve how the input slice is constructed, we need to preserve the structure of this indices mapping.
We propose a generalization of the vector.transfer_read operation, vector.transfer_gather, which behaves similarly to vector.transfer_read but does not assume that the slice is contiguous. Instead, the operation encodes how the slice is gathered by preserving the structure of the indices mapping.
The docs of vector.transfer_gather describe the operation as:
The `vector.transfer_gather` operation is a generalization of the `vector.transfer_read` op, where the slice from which the read is performed is not guaranteed to be contiguous; instead, how the slice is gathered is defined explicitly in the operation. The operation can be thought of as:
- A contiguous slice gathered from the source as described by the operation
- A `vector.transfer_read` on the contiguous slice

The operation defines `permutation_map`, `padding`, `mask`, `in_bounds` in the same way as `vector.transfer_read` does, but on the inferred contiguous slice.

The other parameters of the operation define how the contiguous slice is gathered from the source. The `indices` contain a base to offset the source by. The dimensions of the source which are gathered are specified as an array of indices in `gather_dims`. Dimensions not specified in this array are contiguous. For example, for the following gather:

slice[i, j, k] = source[i + i_offset][j][indices[i][j][k]]

The operation would represent this as:

indices = %i_offset, 0, 0
gather_dims = [2]

For every dimension that is gathered, the operation defines how it is gathered. For each gathered dimension, the operation expects a vector of indices in `index_vecs` to act as a source of indices for that dimension, and an AffineMap in `index_maps` describing how this source of indices is indexed. For example, for the following gather:

slice[i, j, k] = source[i][indices0[i] + offset][indices1[j, k]]

The indexing would be described by:

indices = 0, %offset, 0
gather_dims = [1, 2]
index_vecs = %index_vec1, %index_vec2
index_maps = [
  affine_map<(i, j, k) -> (i)>,
  affine_map<(i, j, k) -> (j, k)>
]

With these additional parameters, the operation can define a supervector read from a non-contiguous slice. For example:

source : memref<8192x8x16xf32>
indices0 : vector<2xindex>
indices1 : vector<4x8xindex>
slice[i, j, k] = source[indices0[k]][j][indices1[i, j]]
vector = read(slice) : memref<8192x8x16xf32> -> vector<2x8x16xf32>

can be represented by:

%vector = vector.transfer_gather %source[0, 0, 0](%indices0, %indices1) {
  gather_dims = [0, 2],
  index_maps = [
    affine_map<(i, j, k) -> (k)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  in_bounds = [true, true, true],
  permutation_map = affine_map<(i, j, k) -> (i, j, k)>
} : memref<8192x8x16xf32> -> vector<2x8x16xf32>
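The documented semantics can be modeled in numpy. The sketch below implements the second documented gather, slice[i, j, k] = source[i][indices0[i] + offset][indices1[j, k]]; the shapes, index values, and offset are made up for illustration:

```python
import numpy as np

# Illustrative shapes and values (not taken from the op docs).
source = np.arange(2 * 5 * 7).reshape(2, 5, 7)
indices0 = np.array([1, 3])                    # indexed by i  (index_map: (i, j, k) -> (i))
indices1 = np.array([[0, 2], [4, 6], [1, 5]])  # indexed by (j, k)
offset = 1

# Scalar reference: slice[i, j, k] = source[i][indices0[i] + offset][indices1[j, k]]
out = np.empty((2, 3, 2), dtype=source.dtype)
for i in range(2):
    for j in range(3):
        for k in range(2):
            out[i, j, k] = source[i, indices0[i] + offset, indices1[j, k]]
```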
Paged Reads with vector.transfer_gather
With this new operation, we can get the “good” IR we want when vectorizing gathers.
The Linalg vectorizer can be updated to emit vector.transfer_gather instead of vector.gather. Note that this makes the Linalg vectorizer simpler: we no longer need to linearize indices or analyze whether they are constant. The IR after vectorizing the gather linalg.generic looks like:
func.func @main(%arg0: tensor<8192x8xf16>, %arg1: tensor<4xindex>) -> tensor<4x8xf16> {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = vector.step : vector<8xindex>
%1 = vector.broadcast %0 : vector<8xindex> to vector<4x8xindex>
%2 = vector.transfer_read %arg1[%c0], %c0 {in_bounds = [true]} : tensor<4xindex>, vector<4xindex>
%3 = vector.broadcast %2 : vector<4xindex> to vector<8x4xindex>
%4 = vector.transpose %3, [1, 0] : vector<8x4xindex> to vector<4x8xindex>
%5 = vector.transfer_gather %arg0[%c0, %c0](%4, %1), %cst {
gather_dims = [0, 1],
index_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]}
: tensor<8192x8xf16>, vector<4x8xindex>, vector<4x8xindex>, vector<4x8xf16>
%6 = tensor.empty() : tensor<4x8xf16>
%7 = vector.transfer_write %5, %6[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf16>, tensor<4x8xf16>
return %7 : tensor<4x8xf16>
}
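As a sanity check, the two 4x8 index vectors built by the vectorizer above (the transposed broadcast of the loaded indices, and the broadcast of vector.step) do reproduce the paged read. A numpy sketch with illustrative index values:

```python
import numpy as np

# Model of the two 4x8 index vectors from the vectorized IR above.
storage = np.arange(8192 * 8, dtype=np.float32).reshape(8192, 8)
ind = np.array([24, 43, 36, 57])  # illustrative, matching tensor<4xindex>

idx0 = np.broadcast_to(ind, (8, 4)).T         # transpose(broadcast(indices))
idx1 = np.broadcast_to(np.arange(8), (4, 8))  # broadcast(vector.step)
out = storage[idx0, idx1]                     # element-wise gather
```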
The Linalg vectorizer vectorizes tensor.extract to a vector.transfer_gather with the most general indices mapping possible, as shown in the IR above. The structure of vector.transfer_gather really shines once its folders/canonicalizations are enabled:
func.func @main(%arg0: tensor<8192x8xf16>, %arg1: tensor<4xindex>) -> tensor<4x8xf16> {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %arg1[%c0], %c0 {in_bounds = [true]} : tensor<4xindex>, vector<4xindex>
%1 = vector.transfer_gather %arg0[%c0, %c0](%0), %cst {
gather_dims = [0],
index_maps = [affine_map<(d0, d1) -> (d0)>]}
: tensor<8192x8xf16>, vector<4xindex>, vector<4x8xf16>
%2 = tensor.empty() : tensor<4x8xf16>
%3 = vector.transfer_write %1, %2[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf16>, tensor<4x8xf16>
return %3 : tensor<4x8xf16>
}
Note that canonicalization of vector.transfer_gather automatically discovered that the innermost dimension is contiguous.
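The reasoning behind this fold can be checked with a small Python sketch: the inner-dimension index vector was built from vector.step, so indexing with it is the identity:

```python
# The inner-dimension index vector emitted by the vectorizer was
# broadcast(vector.step), i.e. idx[i][j] == j for every row i.
idx = [list(range(8)) for _ in range(4)]  # 4x8
# An identity index vector means the "gather" along that dimension is a
# plain contiguous access, so it can be dropped from gather_dims.
is_identity = all(idx[i][j] == j for i in range(4) for j in range(8))
```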
This indices mapping can be exploited during lowering. vector.transfer_gather can further be unrolled along all non-innermost dimensions (this is the standard lowering to LLVM, where all outer dimensions are unrolled since LLVM only supports 1-D vectors):
func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x8xf16>
%0 = vector.load %arg1[0] : memref<4xindex>, vector<4xindex>
%1 = vector.extract %0[0] : index from vector<4xindex>
%3 = vector.extract %0[1] : index from vector<4xindex>
%5 = vector.extract %0[2] : index from vector<4xindex>
%7 = vector.extract %0[3] : index from vector<4xindex>
%2 = vector.load %arg0[%1, 0] : memref<8192x8xf16>, vector<8xf16>
%4 = vector.load %arg0[%3, 0] : memref<8192x8xf16>, vector<8xf16>
%6 = vector.load %arg0[%5, 0] : memref<8192x8xf16>, vector<8xf16>
%8 = vector.load %arg0[%7, 0] : memref<8192x8xf16>, vector<8xf16>
vector.store %2, %alloc[0, 0] : memref<4x8xf16>, vector<8xf16>
vector.store %4, %alloc[1, 0] : memref<4x8xf16>, vector<8xf16>
vector.store %6, %alloc[2, 0] : memref<4x8xf16>, vector<8xf16>
vector.store %8, %alloc[3, 0] : memref<4x8xf16>, vector<8xf16>
return %alloc : memref<4x8xf16>
}
Note that as vector.transfer_gather is unrolled along the outer dimension during lowering, its input slice becomes trivial, so it canonicalizes to vector.transfer_read, which further lowers to vector.load. This shows that the contiguous lowering for vector.transfer_gather composes with existing lowerings.
If the innermost dimension is also gathered, the default lowering for vector.transfer_gather generates a vector.gather operation, which lowers as it does today (to llvm.masked.gather or scalar loads).
Tradeoffs
Extending vector.gather
We could instead extend vector.gather to take indices in COO format, or to take a vector of indices per dimension. This would push the problem to the Linalg vectorizer and other transformations, which would have to analyze the IR to determine whether a dimension is contiguous (information we already had while vectorizing and chose to throw away). https://github.com/llvm/llvm-project/pull/117939 implements such a canonicalizer: it tries to do something similar for 1-D index vectors, but it cannot handle broadcasts or transposes (which require more complex analysis), and it cannot handle cases where there is no vector.step to completely fold away the offset for a dimension.
Extending vector.transfer_read
Instead of adding a new op, we could extend the existing vector.transfer_read op to capture the structure of the input memory slice. However, the contiguous-slice case is a nicely restricted subset covering a very common case, and it has many users; extending it would break so many assumptions in transformations and for downstream users that it does not seem worth it. The new vector.transfer_gather also composes with vector.transfer_read, as shown above: if the input slice structure is trivial (i.e. contiguous), vector.transfer_gather canonicalizes to vector.transfer_read.
Make the Linalg vectorizer and transformations do the leg work
Another possibility is to make the Linalg vectorizer completely unroll the outermost non-contiguous dimensions, leaving gathers (if present) only on the innermost dimension. While this would work, it is not consistent with how the vectorizer handles other operations. The vectorizer never unrolls, because unrolling throws away structural information (making folding and bufferization analysis hard). Also, there is no guarantee that the target has 1-D vectors, so unrolling is not always an option. Singling out vector.gather does not seem like a good solution.
Upstreaming Plan
The operation will be implemented upstream in phases. The expectation is that users will not be disrupted, and will observe no changes until the final phase.
- Add the vector.transfer_gather operation: [mlir][Vector] Introduce vector.transfer_gather by Groverkss · Pull Request #130113 · llvm/llvm-project · GitHub
- Implement a basic lowering to vector.gather by linearizing indices. Switch the Linalg vectorizer to emit vector.transfer_gather and immediately call the lowering to vector.gather (effectively making no observable changes for Linalg vectorizer users).
- Implement unrolling for vector.transfer_gather along outer dimensions (similar to the ones implemented for vector.transfer_read and vector.transfer_write).
- Move the call that lowers vector.transfer_gather to vector.gather into a fallback for vector.transfer_gather unrolling. This will be the PR where Linalg vectorizer users see an observable change.
People who might be interested: @banach-space @dcaballe @nicolasvasilache @ftynse @aartbik @MaheshRavishankar @qed @hanchung