[RFC] Improving gather codegen for Vector Dialect

This RFC aims to improve gather codegen support in MLIR for the commonly taken linalg -> vector -> llvm/spirv codegen path. We first discuss the gather codegen support available today, compare the code currently generated against what “good” code should look like, and propose a solution to bridge the gap between the two.

If you are unfamiliar with what a gather is, PyTorch docs do a good job describing a gather: https://pytorch.org/docs/main/generated/torch.gather.html
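For reference, the semantics of a gather along dim 0 of a 2-D input can be sketched in a few lines of plain Python. This is an illustrative sketch of what torch.gather computes, not PyTorch's implementation:

```python
# Sketch of torch.gather(input, dim=0, index=index) for a 2-D input:
# out[i][j] = input[index[i][j]][j].
def gather_dim0(input, index):
    rows, cols = len(index), len(index[0])
    return [[input[index[i][j]][j] for j in range(cols)] for i in range(rows)]

storage = [[10, 11], [20, 21], [30, 31]]
index = [[2, 0], [1, 1]]
print(gather_dim0(storage, index))  # [[30, 11], [20, 21]]
```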

Motivating Example

We start with a simple example which we will call a “Paged Read”, where the IR gathers from a 2-D storage in a non-contiguous fashion on the outermost dimension. This example can be written in PyTorch as:

def paged_read(storage: torch.Tensor, ind: torch.LongTensor):
    return storage[ind, :]

# Example Input
storage = torch.randn([8192, 8], dtype=torch.float16)
ind = torch.LongTensor([24, 43, 36, 18])

read = paged_read(storage, ind)

In Linalg, this example can be written as (and is exported through Torch-MLIR in the same way):

!storage = tensor<8192x8xf16> 
!ind     = tensor<4xindex>
!x       = tensor<4x8xf16>

#gather = {
    indexing_maps = [affine_map<(page, vec) -> (page)>, 
                     affine_map<(page, vec) -> (page, vec)>], 
    iterator_types = ["parallel", "parallel"]
}

func.func @main(%storage : !storage, %ind: !ind) -> !x {
  %x = tensor.empty() : !x
  %x_g = linalg.generic #gather 
         ins(%ind : !ind) 
         outs(%x : !x) {
  ^bb0(%page: index, %out: f16):
    %vec   = linalg.index 1 : index
    %extracted = tensor.extract %storage[%page, %vec] : !storage
    linalg.yield %extracted : f16
  } -> !x
  return %x_g : !x
}

Today, vectorizing this IR using the Linalg vectorizer and unrolling to 1-D vectors (to target LLVM) produces the following code. The IR is shown post-bufferization because, in MLIR today, unrolling is implemented only on buffers and not on tensors, as unrolling tensors can interfere with bufferization analysis:

func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
  %0 = ub.poison : vector<4x8xindex>
  %1 = ub.poison : vector<8x4xindex>
  %cst = arith.constant dense<8> : vector<8x4xindex>
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4x8xf16>
  %cst_1 = arith.constant dense<true> : vector<4x8xi1>
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x8xf16>
  %2 = vector.load %arg1[0] : memref<4xindex>, vector<4xindex>
  %3 = vector.insert %2, %1 [0] : vector<4xindex> into vector<8x4xindex>
  %4 = vector.insert %2, %3 [1] : vector<4xindex> into vector<8x4xindex>
  %5 = vector.insert %2, %4 [2] : vector<4xindex> into vector<8x4xindex>
  %6 = vector.insert %2, %5 [3] : vector<4xindex> into vector<8x4xindex>
  %7 = vector.insert %2, %6 [4] : vector<4xindex> into vector<8x4xindex>
  %8 = vector.insert %2, %7 [5] : vector<4xindex> into vector<8x4xindex>
  %9 = vector.insert %2, %8 [6] : vector<4xindex> into vector<8x4xindex>
  %10 = vector.insert %2, %9 [7] : vector<4xindex> into vector<8x4xindex>
  %11 = vector.step : vector<8xindex>
  %12 = arith.muli %10, %cst : vector<8x4xindex>
  %13 = vector.transpose %12, [1, 0] : vector<8x4xindex> to vector<4x8xindex>
  %14 = vector.insert %11, %0 [0] : vector<8xindex> into vector<4x8xindex>
  %15 = vector.insert %11, %14 [1] : vector<8xindex> into vector<4x8xindex>
  %16 = vector.insert %11, %15 [2] : vector<8xindex> into vector<4x8xindex>
  %17 = vector.insert %11, %16 [3] : vector<8xindex> into vector<4x8xindex>
  %18 = arith.addi %17, %13 : vector<4x8xindex>
  %19 = vector.gather %arg0[0, 0] [%18], %cst_1, %cst_0 : memref<8192x8xf16>, vector<4x8xindex>, vector<4x8xi1>, vector<4x8xf16> into vector<4x8xf16>
  %20 = vector.extract %19[0] : vector<8xf16> from vector<4x8xf16>
  vector.store %20, %alloc[0, 0] : memref<4x8xf16>, vector<8xf16>
  %21 = vector.extract %19[1] : vector<8xf16> from vector<4x8xf16>
  vector.store %21, %alloc[1, 0] : memref<4x8xf16>, vector<8xf16>
  %22 = vector.extract %19[2] : vector<8xf16> from vector<4x8xf16>
  vector.store %22, %alloc[2, 0] : memref<4x8xf16>, vector<8xf16>
  %23 = vector.extract %19[3] : vector<8xf16> from vector<4x8xf16>
  vector.store %23, %alloc[3, 0] : memref<4x8xf16>, vector<8xf16>
  return %alloc : memref<4x8xf16>
}

Notice the vector.gather operation produced. When lowered, this IR produces llvm.masked.gather intrinsics or scalar loads, depending on the lowering path taken.

It is generally better to generate contiguous loads whenever possible. A better lowering would unroll along the gathered (outer) dimension and perform contiguous loads on the inner dimension. The expected “good” IR should look like:

func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x8xf16>
  %0 = vector.load %arg1[0] : memref<4xindex>, vector<4xindex>
  %1 = vector.extract %0[0] : index from vector<4xindex>
  %3 = vector.extract %0[1] : index from vector<4xindex>
  %5 = vector.extract %0[2] : index from vector<4xindex>
  %7 = vector.extract %0[3] : index from vector<4xindex>
  %2 = vector.load %arg0[%1, 0] : memref<8192x8xf16>, vector<8xf16>
  %4 = vector.load %arg0[%3, 0] : memref<8192x8xf16>, vector<8xf16>
  %6 = vector.load %arg0[%5, 0] : memref<8192x8xf16>, vector<8xf16>
  %8 = vector.load %arg0[%7, 0] : memref<8192x8xf16>, vector<8xf16>
  vector.store %2, %alloc[0, 0] : memref<4x8xf16>, vector<8xf16>
  vector.store %4, %alloc[1, 0] : memref<4x8xf16>, vector<8xf16>
  vector.store %6, %alloc[2, 0] : memref<4x8xf16>, vector<8xf16>
  vector.store %8, %alloc[3, 0] : memref<4x8xf16>, vector<8xf16>
  return %alloc : memref<4x8xf16>
}

The Abstraction Gap

To understand why we cannot generate the “good” IR today, we need a closer look at vector.gather, which is our vectorization target from Linalg today. From the docs:

The gather operation returns an n-D vector whose elements are either loaded from memory or ranked tensor, or taken from a pass-through vector, depending on the values of an n-D mask vector. If a mask bit is set, the corresponding result element is defined by the base with indices and the n-D index vector (each index is a 1-D offset on the base). Otherwise, the corresponding element is taken from the n-D pass-through vector.

Notice that each element of the index vector is a single 1-D offset into the N-D memory. This means the gather operation effectively treats the N-D memory as 1-D memory. By doing this, we lose any per-dimension information we could exploit. For example, in the paged read, the fact that one dimension is contiguous while the other is not is lost, because the dimension information has been linearized.
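To make the loss of structure concrete, here is a plain-Python illustration (not actual MLIR semantics) of what the vectorizer computes today for the paged read: the row index and the column index are folded into one linear offset per element, after which the contiguity of the column dimension is no longer visible:

```python
# Paged read over a conceptual 2-D storage with 8 columns.
ncols = 8
ind = [24, 43, 36, 18]  # gathered rows, known only at runtime

# What vector.gather consumes today: one linear offset per element,
# offset[i][j] = ind[i] * ncols + j. Once linearized, nothing records
# that j walks a contiguous dimension while i is gathered.
offsets = [[ind[i] * ncols + j for j in range(ncols)] for i in range(len(ind))]
print(offsets[0])  # [192, 193, 194, 195, 196, 197, 198, 199]
```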

Before vectorization, we know that a dimension being accessed is contiguous (the dimension comes from a linalg.index), but after vectorization this information is thrown away. This points to an abstraction gap, caused by a missing Vector dialect operation that preserves the structure of the indexing.

Preserving Structure on Indices and Memory Accesses

To build a better-structured abstraction for gathers, we look at the vectorization target for contiguous loads, vector.transfer_read, and try to learn from it. vector.transfer_read preserves the structure of how memory is accessed for contiguous loads. The input slice structure is simple: it is a contiguous slice. The operation preserves how a vector is extracted from this contiguous slice.

For gathers, the input slice structure is not a contiguous slice, and instead is constructed from the indices mapping input_slice[i] = input[indices[i]]. To preserve how the input slice is constructed, we need to preserve the structure of the indices mapping.

We propose a generalization of the vector.transfer_read operation, vector.transfer_gather, which behaves like vector.transfer_read but does not assume that the slice is contiguous. Instead, the operation encodes how the slice is gathered by preserving the structure of the indices mapping.

The docs of vector.transfer_gather describe the operation as:

The `vector.transfer_gather` operation is a generalization of the `vector.transfer_read` op, where the slice from which the read is performed is not guaranteed to be contiguous; instead, how the slice is gathered is defined explicitly in the operation.

The operation can be thought of as:

  1. A contiguous slice gathered from the source as described by the operation
  2. A vector.transfer_read on that contiguous slice

The operation defines permutation_map, padding, mask, and in_bounds in
the same way as vector.transfer_read does, but on the inferred
contiguous slice.

The other parameters of the operation define how the contiguous slice is
gathered from the source.

The indices contain a base offset into the source. The dimensions of
the source which are gathered are specified as an array of indices in
gather_dims. Dimensions not listed in this array are contiguous. For
example, for the following gather:

slice[i, j, k] = source[i + i_offset][j][indices[i][j][k]]

The operation would represent this as:

indices = %i_offset, 0, 0
gather_dims = [2]

For every dimension that is gathered, the operation defines how it is
gathered. For each gathered dimension, the operation expects a vector of
indices in index_vecs to act as a source of indices for that dimension
and an AffineMap in index_maps describing how this source of indices is
indexed. For example, for the following gather:

slice[i, j, k] = source[i][indices0[i] + offset][indices1[j, k]]

The indexing would be described by:

indices      = 0, %offset, 0
gather_dims  = [1, 2]
index_vecs   = %index_vec1, %index_vec2
index_maps = [
  affine_map<(i, j, k) -> (i)>,
  affine_map<(i, j, k) -> (j, k)>
]

With these additional parameters, the operation can define a supervector
read from a non-contiguous slice. For example:

source: memref<8192x8x16xf32>
indices0 : vector<16xindex>
indices1 : vector<2x8xindex>

slice[i, j, k] = source[indices0[k]][j][indices1[i, j]]
vector = read(slice) : memref<8192x8x16xf32> -> vector<2x8x16xf32>

Can be represented by:

%vector = vector.transfer_gather %source[0, 0, 0](%indices0, %indices1) {
  gather_dims = [0, 2],
  index_maps = [
    affine_map<(i, j, k) -> (k)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  in_bounds = [true, true, true],
  permutation_map = affine_map<(i, j, k) -> (i, j, k)>
} : memref<8192x8x16xf32> -> vector<2x8x16xf32>
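A plain-Python model of the proposed semantics may make this easier to follow. This is an illustrative sketch (the function and parameter names are ours, not part of the op definition), restricted to the case with no masking, padding, or permutation: each gathered dimension looks its index up through its index vector via the corresponding index map, and every other dimension indexes the source directly:

```python
import itertools

def transfer_gather(source, base, gather_dims, index_vecs, index_maps,
                    result_shape):
    """Sketch: result[idx] = source[src_idx], where for each dim d,
    src_idx[d] = base[d] + index_vecs[k][index_maps[k](idx)] if d is the
    k-th gathered dim, and base[d] + idx[d] otherwise."""
    def read(nested, pos):
        # Index a (possibly nested) list by a tuple of indices.
        for p in pos:
            nested = nested[p]
        return nested

    result = {}
    for idx in itertools.product(*(range(s) for s in result_shape)):
        src_idx = []
        for d in range(len(result_shape)):
            if d in gather_dims:
                # Gathered dim: route through the index vector / index map.
                k = gather_dims.index(d)
                src_idx.append(base[d] + read(index_vecs[k], index_maps[k](idx)))
            else:
                # Non-gathered dim: contiguous, index the source directly.
                src_idx.append(base[d] + idx[d])
        result[idx] = read(source, src_idx)
    return result

# Paged read: gather rows [2, 0] of a 3x2 source.
src = [[10, 11], [20, 21], [30, 31]]
out = transfer_gather(src, base=(0, 0), gather_dims=[0],
                      index_vecs=[[2, 0]],
                      index_maps=[lambda idx: (idx[0],)],
                      result_shape=(2, 2))
print([out[(0, 0)], out[(0, 1)], out[(1, 0)], out[(1, 1)]])  # [30, 31, 10, 11]
```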

Paged Reads with vector.transfer_gather

With this new operation, we can get the “good” IR we want when vectorizing gathers.

The Linalg vectorizer can be updated to emit vector.transfer_gather instead of vector.gather. Note that this makes the Linalg vectorizer simpler: we no longer need to linearize indices or analyze whether the indices are constant. The IR after vectorizing the gather linalg.generic looks like:

func.func @main(%arg0: tensor<8192x8xf16>, %arg1: tensor<4xindex>) -> tensor<4x8xf16> {
  %cst = arith.constant 0.000000e+00 : f16
  %c0 = arith.constant 0 : index
  %0 = vector.step : vector<8xindex>
  %1 = vector.broadcast %0 : vector<8xindex> to vector<4x8xindex>
  %2 = vector.transfer_read %arg1[%c0], %c0 {in_bounds = [true]} : tensor<4xindex>, vector<4xindex>
  %3 = vector.broadcast %2 : vector<4xindex> to vector<8x4xindex>
  %4 = vector.transpose %3, [1, 0] : vector<8x4xindex> to vector<4x8xindex>

  %5 = vector.transfer_gather %arg0[%c0, %c0](%4, %1), %cst {
    gather_dims = [0, 1], 
    index_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]} 
      : tensor<8192x8xf16>, vector<4x8xindex>, vector<4x8xindex>, vector<4x8xf16>

  %6 = tensor.empty() : tensor<4x8xf16>
  %7 = vector.transfer_write %5, %6[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf16>, tensor<4x8xf16>
  return %7 : tensor<4x8xf16>
}

The Linalg vectorizer vectorizes tensor.extract to a vector.transfer_gather with the most general indices mapping possible, as shown in the above IR. The structure of the vector.transfer_gather operation really shines once its folders/canonicalizations are enabled:

func.func @main(%arg0: tensor<8192x8xf16>, %arg1: tensor<4xindex>) -> tensor<4x8xf16> {
  %cst = arith.constant 0.000000e+00 : f16
  %c0 = arith.constant 0 : index
  %0 = vector.transfer_read %arg1[%c0], %c0 {in_bounds = [true]} : tensor<4xindex>, vector<4xindex>

  %1 = vector.transfer_gather %arg0[%c0, %c0](%0), %cst {
    gather_dims = [0],
    index_maps = [affine_map<(d0, d1) -> (d0)>]}
      : tensor<8192x8xf16>, vector<4xindex>, vector<4x8xf16>

  %2 = tensor.empty() : tensor<4x8xf16>
  %3 = vector.transfer_write %1, %2[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf16>, tensor<4x8xf16>
  return %3 : tensor<4x8xf16>
}

Note that canonicalizing vector.transfer_gather automatically discovered that the innermost dimension is contiguous.

This indices mapping can be exploited during lowering. The vector.transfer_gather operation can further be unrolled along all non-innermost dimensions (this is the standard lowering path to LLVM, where all outer dimensions are unrolled, since LLVM supports only 1-D vectors):

func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x8xf16>
  %0 = vector.load %arg1[0] : memref<4xindex>, vector<4xindex>
  %1 = vector.extract %0[0] : index from vector<4xindex>
  %3 = vector.extract %0[1] : index from vector<4xindex>
  %5 = vector.extract %0[2] : index from vector<4xindex>
  %7 = vector.extract %0[3] : index from vector<4xindex>
  %2 = vector.load %arg0[%1, 0] : memref<8192x8xf16>, vector<8xf16>
  %4 = vector.load %arg0[%3, 0] : memref<8192x8xf16>, vector<8xf16>
  %6 = vector.load %arg0[%5, 0] : memref<8192x8xf16>, vector<8xf16>
  %8 = vector.load %arg0[%7, 0] : memref<8192x8xf16>, vector<8xf16>
  vector.store %2, %alloc[0, 0] : memref<4x8xf16>, vector<8xf16>
  vector.store %4, %alloc[1, 0] : memref<4x8xf16>, vector<8xf16>
  vector.store %6, %alloc[2, 0] : memref<4x8xf16>, vector<8xf16>
  vector.store %8, %alloc[3, 0] : memref<4x8xf16>, vector<8xf16>
  return %alloc : memref<4x8xf16>
}

Note that during the lowering, as vector.transfer_gather is unrolled along the outer dimension, its input slice becomes trivial, so it canonicalizes to vector.transfer_read, which further lowers to vector.load. This shows that the contiguous lowering for vector.transfer_gather composes with the existing lowerings.
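The shape of that lowering can be sketched in plain Python (illustrative only, with hypothetical names): once the gathered outer dimension is unrolled, each remaining access is a whole contiguous row, the moral equivalent of a vector.load per extracted index:

```python
def paged_read_unrolled(storage, ind):
    # Unrolled lowering of the paged read: one scalar index extract plus
    # one contiguous row copy per gathered row. No per-element gather
    # remains once the gathered (outer) dimension is unrolled.
    out = []
    for i in range(len(ind)):
        row = ind[i]                 # scalar extract from the index vector
        out.append(storage[row][:])  # contiguous "vector.load" of one row
    return out

storage = [[r * 100 + c for c in range(4)] for r in range(10)]
print(paged_read_unrolled(storage, [7, 2, 5])[0])  # [700, 701, 702, 703]
```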

If the innermost dimension is also gathered, the default lowering for vector.transfer_gather generates a vector.gather operation, which lowers as it does today (to llvm.masked.gather or scalar loads).

Tradeoffs

Extending vector.gather

We could also extend vector.gather to take indices in COO format, or to take a vector of indices per dimension. This would push the problem onto the Linalg vectorizer and other transformations, which would need to analyze the IR to discover whether a dimension is contiguous (information we already had while vectorizing and chose to throw away). https://github.com/llvm/llvm-project/pull/117939 implements such a canonicalizer for 1-D index vectors, but it cannot handle broadcasts or transposes (which require more complex analysis), and it cannot handle cases where there is no vector.step to completely fold away the offset for a dimension.

Extending vector.transfer_read

Instead of adding a new op, vector.transfer_gather, we could extend the existing vector.transfer_read op to capture the structure of the input memory slice. However, the contiguous-slice case is a nicely restricted subset of a very commonly occurring case and has many users; extending it would break so many assumptions in transformations and for downstream users that it does not seem worth it. The new vector.transfer_gather also composes with vector.transfer_read, as shown above: if the input slice structure is trivial (i.e. contiguous), vector.transfer_gather canonicalizes to vector.transfer_read.

Make the Linalg vectorizer and transformations do the leg work

Another possibility is to make the Linalg vectorizer completely unroll the outermost non-contiguous dimensions, leaving gathers, if present, only on the innermost dimension. While this would work, it is not consistent with how the vectorizer handles other operations. The vectorizer never unrolls, because unrolling throws away structural information (making folding and bufferization analysis hard). Also, the target is not guaranteed to have 1-D vectors, so unrolling is not always an option. Singling out gathers for special treatment does not seem like a good solution.

Upstreaming Plan

The operation will be implemented upstream in phases. The expectation is that users will not be disrupted, and will observe no changes until the final phase.

  1. Add vector.transfer_gather operation : [mlir][Vector] Introduce vector.transfer_gather by Groverkss · Pull Request #130113 · llvm/llvm-project · GitHub
  2. Implement basic lowering to vector.gather by linearizing indices. Switch Linalg vectorizer to emit vector.transfer_gather and immediately call the lowering to vector.gather (Effectively making no observable changes for Linalg vectorizer users).
  3. Implement unrolling for vector.transfer_gather along outer dimensions (Similar to ones implemented for vector.transfer_read and vector.transfer_write).
  4. Move the call to lower vector.transfer_gather to vector.gather to a fallback for unrolling for vector.transfer_gather. This will be the PR where Linalg vectorizer users have an observable change.

People who might be interested: @banach-space @dcaballe @nicolasvasilache @ftynse @aartbik @MaheshRavishankar @qed @hanchung


Hey @Groverkss , thank you for working on this and for sharing your ideas with us!

In the interest of full transparency, I previously reviewed early drafts of this RFC, and some of my lower-level comments have already been incorporated. Here, I’d like to focus on my high-level thoughts.

As the author of the current Linalg logic to vectorize tensor.extract, I am well aware that the current abstractions are not suitable/powerful enough to generate optimal code for the Linalg example that you shared. However, I wonder if they should be?

Are we really dealing with a Gather operation?

As an initial point, I want to highlight that neither your PyTorch nor Linalg examples are truly “gather” accesses. That’s evident from the fact that your PyTorch motivating example does not use torch.gather.

Instead, what I see is a “paged read” or a “distributed read”. (I’m also mindful of your presentation on MLIR Vector Distribution.)

What are we dealing with, then?

To better understand the underlying problem, let’s consider three possible scenarios for tensor.extract accesses, using a 7×6 matrix as an example.

  • . → Regular elements.
  • X → Elements being read.

In all cases we are reading 3 rows.

EXAMPLE 1: “Block Read” - reads 3 consecutive rows

.  .  .  .  .  .
.  .  .  .  .  .
X  X  X  X  X  X
X  X  X  X  X  X
X  X  X  X  X  X
.  .  .  .  .  .
.  .  .  .  .  .

EXAMPLE 2: “Strided read” - reads 3 rows, every row is a “stride” apart from the previous row (stride = 1 row):

.  .  .  .  .  .
.  .  .  .  .  .
X  X  X  X  X  X
.  .  .  .  .  .
X  X  X  X  X  X
.  .  .  .  .  .
X  X  X  X  X  X

EXAMPLE 3: “Gather load”

.  X  .  .  .  X
.  .  .  X  .  .
X  .  .  .  X  .
.  .  X  .  .  .
X  .  .  X  .  .
.  .  .  .  X  .
.  X  .  .  .  .
  • EXAMPLE 1 can be modelled using vector.load,
  • EXAMPLE 2 can be modelled using vector.extract_strided_slice,
  • EXAMPLE 3 can be modelled using vector.gather.

Finally, this is the example from your RFC.

EXAMPLE 4: “Distributed/paged read” - read 3 “random” rows

X  X  X  X  X  X
.  .  .  .  .  .
X  X  X  X  X  X
X  X  X  X  X  X
.  .  .  .  .  .
.  .  .  .  .  .
.  .  .  .  .  .

As you observe in your post, there is no good (single-op) representation for EXAMPLE 4 in Vector. However, should there be? Like many things in Vector, this can be modelled by composing other Vector Ops.

Btw, I briefly touched on how the vectorizer deals with different access patterns in my LLVM Dev '24 presentation (starts @ ~7 min.).

Propagating Complexities from Higher Levels to Vector

The complexity in your example originates from the way the program is represented at the PyTorch level. It doesn’t seem right to push that complexity all the way down to Vector and expect it to introduce new abstractions to handle it.

We should not forget that TorchIR, Linalg and Tensor exist along the way. Would it not make more sense to address these complexities at those higher levels of abstraction instead?

Explosion of “Read”/“Write” Ops in Vector

Currently, there are 41 operations in Vector (excluding deprecated vector.{extract|insert}element). Out of these, 16 are some form of “read” or “write” operations. Vector already provides a broad and flexible set of operations for data movement. IMHO, adding even more complex (*) “read”/“write” operations to Vector does call for careful consideration.

Btw, I’m working on a new taxonomy for these operations and have gathered them here:

(*) What’s being proposed here is much more complex than any of the Ops that we have today.

How About Extending Tensor + MemRef Instead?

This RFC proposes a new, complex boundary op between Tensor/MemRef and Vector.
But this begs the question: why not extend Tensor and/or MemRef instead?

Some of the existing read/write ops in Vector complement similar ops in MemRef and Tensor. Strided accesses are my favorite example. However, overall, there are many more “read”/“write” ops in Vector than in MemRef/Tensor. Perhaps what’s missing is e.g. tensor.view? Or, how about extending memref.subview?

These are genuine questions. We’re dealing with a fundamental challenge, and we seem to have jumped directly to the conclusion that the solution should go into Vector. I am not convinced.

Limitations of the Linalg Vectorizer

Your linalg.generic example is rather complex, and since it doesn’t originate from torch.gather, it feels somewhat artificial. Why not decompose it? The vectorizer is optimized towards a “common denominator”, and your linalg.generic is quite far from that.

Crucially, the Linalg vectorizer should not be expected to handle every complex input optimally. In fact, there are many ops that it does not vectorize, e.g. 2-D and 3-D convolutions. Instead, such operations are decomposed before reaching the vectorizer (“pre-processing”).

Similarly, the vectorized code produced by the Linalg vectorizer is intentionally “basic”/“crude”. Many transformations run after vectorization to optimize things further. For example, the vectorizer inserts masks conservatively, and subsequent transformations remove them where possible — and they do a great job!

A Simpler Alternative

Instead of introducing a new op, could you split your linalg.generic into three separate linalg.generics, each reading one row at a time?

func.func @func(%idx : index, %3: tensor<8192x8xf16>) -> tensor<1x8xf16> {
  %c0_i64 = arith.constant 0 : i64
  %5 = tensor.empty() : tensor<1x8xf16>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
    outs(%5 : tensor<1x8xf16>) {
      ^bb0(%out: f16):
        %11 = linalg.index 1 : index
        %12 = linalg.index 0 : index
        %13 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 8)>()[%11, %12]
        %extracted = tensor.extract %3[%idx, %13] : tensor<8192x8xf16>
        linalg.yield %extracted : f16
    } -> tensor<1x8xf16>
  return %6 : tensor<1x8xf16>
}


module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [1, 8] {vectorize_nd_extract} : !transform.any_op
     transform.yield
   }
}

The vectorizer happily vectorizes this as vector.transfer_read.

Final Points

I’ve tried my best to enumerate my high-level concerns. All in all, I think that we should refrain from adding vector.transfer_{gather|scatter} to Vector. The challenges described here can be addressed through existing mechanisms. And if new abstractions are really needed, it’s not clear to me that Vector is the right place for expansion.

If people disagree, then perhaps it’s time to revisit the Vector charter. Just to be clear: I don’t necessarily see “revisiting the charter” as an exercise in expanding Vector’s scope.

Lastly, I wanted to make sure I provide an alternative approach to your motivating example. Hopefully, my suggestion unblocks you.

-Andrzej

As an initial point, I want to highlight that neither your PyTorch nor Linalg examples are truly “gather” accesses. That’s evident from the fact that your PyTorch motivating example does not use torch.gather.

Both examples fit the definition of a gather as defined by PyTorch (I’m using PyTorch as a reference because it’s a very widely used framework). The PyTorch example I wrote is a shorthand way of writing a gather. Here is the same example using torch.gather:

def paged_read(storage: torch.Tensor, ind: torch.Tensor):
    ind = ind.unsqueeze(1).broadcast_to(ind.shape[0], storage.shape[1])
    return torch.gather(storage, dim=0, index=ind)

# Example Input
storage = torch.randn([8192, 8], dtype=torch.float16)
ind = torch.LongTensor([24, 43, 36, 18])

read = paged_read(storage, ind)

This produces the same code and the same result as both the PyTorch and the Linalg example in this RFC. This is clearly a gather, because it takes a list of indices known only at runtime, which are neither contiguous nor ordered.
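The claimed equivalence is easy to check. Here is a plain-Python sketch (standing in for the two PyTorch forms, so it runs without torch) comparing the storage[ind, :] form against the unsqueeze-broadcast-gather form:

```python
def fancy_index(storage, ind):
    # The storage[ind, :] form: select whole rows.
    return [storage[i][:] for i in ind]

def gather_form(storage, ind):
    # The unsqueeze + broadcast + gather(dim=0) form:
    # out[i][j] = storage[index[i][j]][j] with index[i][j] = ind[i].
    ncols = len(storage[0])
    index = [[ind[i]] * ncols for i in range(len(ind))]  # broadcast
    return [[storage[index[i][j]][j] for j in range(ncols)]
            for i in range(len(ind))]

storage = [[r * 10 + c for c in range(8)] for r in range(64)]
assert fancy_index(storage, [24, 43, 36]) == gather_form(storage, [24, 43, 36])
```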

Instead, what I see is a “paged read” or a “distributed read”. (I’m also mindful of your presentation on MLIR Vector Distribution. )

EXAMPLE 4: “Distributed/paged read” - read 3 “random” rows

This RFC is unrelated to Vector Distribution; let’s not mix these things. The examples I have provided throughout the RFC use CPU codegen.

As you observe in your post, there is no good (single-op) representation for EXAMPLE 4 in Vector . However, should there be? Like many things in Vector , this can be modelled by composing other Vector Ops.

I don’t think you can compose existing Vector dialect operations to model this. Could you give an example, please? In this RFC, I did my best to present why there is no good form to represent this today. I’m up for reusing existing abstractions, but I don’t see a good one. And it’s not only me who has noticed this: @nicolasvasilache observed a similar issue with the Vector dialect gather representation: Lowering of scatter operations - #8 by nicolasvasilache (Nicolas mentions extending vector.gather, but I have noted in the RFC why it’s better to improve the multi-dimensional gather abstraction even further).

@Hardcode84 also noticed a similar problem and decided to emit the ops in an unrolled form in the compiler he is working on: [TKW] Unroll gathers/scatters by Hardcode84 · Pull Request #541 · iree-org/iree-turbine · GitHub

It doesn’t seem right to push that complexity all the way down to Vector and expect it to introduce new abstractions to handle it.

Instead of introducing a new op, could you split your linalg.generic into three separate linalg.generics, each reading one row at a time?

The suggestion you mention does not work for dynamic cases. We should always expect dynamically shaped inputs, and the examples I’m currently working on are dynamically shaped on the gather dimension. To handle these cases efficiently, you want to mask the operation on the gather dimension. This is also why lowering to loops over the gathered dimension at the Linalg level is problematic: once you do that, you cannot mask easily.

I don’t think this RFC introduces any new paradigm to the Vector dialect. vector.transfer_read/write are very widely used operations and, in my experience, a very good abstraction. This RFC introduces an abstraction at the same level, but removes the restriction that the slice read is contiguous. This operation belongs in exactly the same place as vector.transfer_read/write.

Currently, there are 41 operations in Vector (excluding deprecated vector.{extract|insert}element). Out of these, 16 are some form of “read” or “write” operations. Vector already provides a broad and flexible set of operations for data movement. IMHO, adding even more complex (*) “read”/“write” operations to Vector does call for careful consideration.

I agree it calls for careful consideration, which is why I tried my best to describe what we cannot represent with the Vector dialect today and, in this reply, why simply unrolling at the linalg.generic level does not work. I’m happy to provide more prototyping, more examples, etc. if more justification is needed.

Some of the existing read/write ops in Vector complement similar ops in MemRef and Tensor. Strided accesses are my favorite example. However, overall, there’s many more “read”/“write” Ops in Vector than there is in MemRef/Tensor. Perhaps what’s missing is e.g. tensor.view? Or, how about extending memref.subview?

My opinion is that this op should exist wherever vector.transfer_read/write exist. It is at the same abstraction level as the other transfer ops.

Your linalg.generic example is rather complex, and since it doesn’t originate from torch.gather, it feels somewhat artificial. Why note decompose it? The vectorizer is optimized towards a “common denominator” and your linalg.generic is quite far off from that.

Could you give an example of a better way to write the above linalg.generic instead? I’m happy to reconsider and use existing ops if given an example. I have explained above why the simpler alternative doesn’t work (dynamic shapes).

I’ve tried my best to enumerate my high-level concerns. All in all, I think that we should refrain from adding vector.transfer_{gather|scatter} to Vector. The challenges described here can be addressed through existing mechanisms. And if new abstractions are really needed, it’s not clear to me that Vector is the right place for expansion.

If people disagree, then perhaps it’s time to revisit the Vector charter. Just to be clear: I don’t necessarily see “revisiting the charter” as an exercise in expanding Vector’s scope.

I really don’t see a good example of addressing these challenges through existing mechanisms yet. These are not new abstractions different from the existing ones; they are at the same level as vector.transfer_read/write (the op even implements VectorTransferInterface, which already exists).

my 2c here:

  1. In our compiler we are not using linalg pipeline, instead generating vector ops directly, with indices computed from user-specified sympy expressions.
  2. The Vector dialect is missing a proper abstraction for multi-dimensional gathers/scatters. As was mentioned, while the input memref and result vector type can be multi-dimensional, the offset vector provides only a single value per result element. The current spec doesn’t even specify how to interpret it for N-D memrefs; the current lowering interprets it as a linear offset into the memref data, a fact we are (ab)using.
  3. Treating these as linear offsets breaks the lowering layering: to properly compute such offsets you need to know the memref layout and explicitly extract the memref strides.
  4. The PR @Groverkss mentioned is a mess, as we still have both gather and non-gather code paths, and for the non-gather path we need to manually linearize the memref to be able to consume gather-style linearized offsets.

I’m generally +1 on the current proposal. Although vector.transfer_xyz are quite high-level ops, it would be nice to have a low-level counterpart too, one that doesn’t need explicit linearization (e.g. a vector.gather consuming a list of offset vectors, one per memref dim).

What hardware unit(s) does this correspond with?

None. It will be decomposed later, but we still need something to represent gathers/scatters on N-D memrefs. The current vector.gather/scatter semantics are not even properly specified for N-D memrefs (or I wasn’t able to find the spec) and are just derived from the lowering.

upd: by “decomposed” I mean translated to a 1-D gather during the llvm/whatever lowering.

There are GPUs with 2d memory access instructions FYI: SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.asciidoc at main · KhronosGroup/SPIRV-Registry · GitHub

Hey Kunwar, thanks a lot for bringing this up! Yeah, gathers and “contiguous loads in disguise” have been a major pain since we introduced vectorization support for tensor.extract. I agree with the motivation and like how you are approaching the problem. However, I’d suggest a slight change in focus. Here are my thoughts:

  • We know that the vector.gather operation is incomplete, especially when it comes to multi-dim gathers. We have to make sure that the solution doesn’t lead to a parallel vector gather representation that leaves the existing “dysfunctional” one hanging in there. That wouldn’t be a great outcome. Instead, I think the approach should focus on filling the gaps in the multi-dimensional semantics and adding whatever functionality vector.gather is missing today.

  • When bringing an operation to the multi-dimensional domain it’s expected that certain traits need to be encoded per dimension. We do that for in_bounds, for example, among many other cases. For the gather case, it shouldn’t be different! We should aim to encode whether any dimension is contiguous or not, not only the innermost one. Knowing that the outer dimensions of a memory access are contiguous, even when the innermost is not, can enable some optimizations. This information could be encoded per dimension by adding new attributes (e.g., [contiguous, random, strided]) to vector.gather.

  • I would leave any kind of indexing/permutation maps out of the picture to simplify an already highly complex operation. I wouldn’t scope this within the transfer op family either. Transfer ops have been great at abstracting away all the details of memory loads and stores for cases where we really don’t care much about the memory access pattern or how the data is loaded. However, they have been a struggle when we need to reason about that information, because they concentrate too much information. For the gather case, we are looking at encoding specifics of the memory access itself, which doesn’t align with the aforementioned main goals of transfer ops.

I hope this makes sense! I’m open to discussing this further in a higher bandwidth venue (call?, Euro LLVM?, …) or whatever works for you!

Thanks,
Diego

Thanks for your reply, Kunwar!

I might have over-indexed on your “motivating” example.

Let me take a step back. In my earlier post, I shared examples to demonstrate my mental model and my understanding of the problem we’re trying to solve. My EXAMPLE 4 is what I saw in your PyTorch example (let’s ignore naming for now):

X  X  X  X  X  X
.  .  .  .  .  .
X  X  X  X  X  X
X  X  X  X  X  X
.  .  .  .  .  .
.  .  .  .  .  .
.  .  .  .  .  .

And that should be modelled like this at the Vector level (*) :

func.func @main(%arg0: memref<8192x8xf16>, %arg1: memref<4xindex>) -> memref<4x8xf16> {
  // (...)
  // (%cN index constants assumed defined above; vector.load/store take SSA indices.)
  %2 = vector.load %arg0[%1, %c0] : memref<8192x8xf16>, vector<8xf16>
  %4 = vector.load %arg0[%3, %c0] : memref<8192x8xf16>, vector<8xf16>
  %6 = vector.load %arg0[%5, %c0] : memref<8192x8xf16>, vector<8xf16>
  %8 = vector.load %arg0[%7, %c0] : memref<8192x8xf16>, vector<8xf16>
  vector.store %2, %alloc[%c0, %c0] : memref<4x8xf16>, vector<8xf16>
  vector.store %4, %alloc[%c1, %c0] : memref<4x8xf16>, vector<8xf16>
  vector.store %6, %alloc[%c2, %c0] : memref<4x8xf16>, vector<8xf16>
  vector.store %8, %alloc[%c3, %c0] : memref<4x8xf16>, vector<8xf16>
  // (...)
}

You should be able to get here (i.e. to contiguous loads) by:

  1. Unrolling linalg.generic that’s produced by torch-mlir,
  2. Vectorizing each linalg.generic from Step 1 separately.

OK, now I see. To me, there’s a framework-level “gather” and a Vector/“hardware”-level gather, which are two different things. IMO, Vector should focus on modelling the lower-level view.

That appears to be a very recent spec (2025?) rather than actual hardware.

This might be an unpopular opinion, but for proposals like this, I would hope there’s actual hardware we can test in-tree to demonstrate the need for new Ops (within reason - there are exceptions like vector.shape_cast and vector.type_cast).

I would take a step back and question whether tensor.extract + linalg.generic is the right abstraction here to begin with. We seem to insist on this model:

%out = linalg.generic () {
  tensor.extract
} -> tensor<NxMxf32>

However, I feel that we should explore replacing the tensor.extract + linalg.generic model with (for certain cases):

  tensor.gather () -> tensor<NxMxf32>

This way, we avoid “recovering” access patterns later, since they would be explicitly encoded in the representation.
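For the paged-read example, a tensor.gather form might look like the following sketch (assuming the indices have been reshaped to tensor<4x1xindex>, one coordinate per gathered element; per the op’s documentation, gathered dimensions become unit dims in the result):

```mlir
// Gather whole rows of %storage selected by %ind (sketch only).
%rows = tensor.gather %storage[%ind] gather_dims([0])
    : (tensor<8192x8xf16>, tensor<4x1xindex>) -> tensor<4x1x8xf16>
```

The access pattern (“rows along dim 0, selected by %ind”) is then explicit in the op rather than hidden inside a linalg.generic body.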

Following on from that - do we need to take the output from torch-mlir at face value? Surely that can be modified?

I hope this clarifies some of the points I made earlier. All in all, I do agree that we are missing an abstraction. However, it’s not obvious to me where that abstraction belongs. This is a tricky problem, so kudos for taking this on @Groverkss !

-Andrzej

(*) Btw, your PyTorch example loads 3 rows, whereas your Linalg/Vector example loads 4 rows. That’s beside the point - just explaining the source of inconsistency in my examples.


Thanks for the replies everyone! I need to write a more comprehensive reply, but here is a quick note on the fact that the recommended existing solution does not work.

As mentioned previously, this is not possible. The indices vector can have a dynamic size, which you cannot unroll; you want to mask the indices vector instead.

I want to point this out before replying more, to make sure we agree that no good solution exists today for getting the “good” code.

OK, I misunderstood that part in your earlier reply. So the number of “rows” (*) to read can be dynamic. How about:

scf.for
  linalg.generic

? Again, just sharing my mental model. To me, with “dynamism”, there’s always going to be a loop hidden somewhere, so it’s a question of where to materialise it. Perhaps “masking” helps; that’s not clear to me.

-Andrzej

(*) Using “rows” for simplicity, but I appreciate it could be any dim.

  • We know that the vector.gather operation is incomplete, especially when it comes to multi-dim gathers. We have to make sure that the solution doesn’t lead to a parallel vector gather representation that leaves the existing “dysfunctional” one hanging in there. That wouldn’t be a great outcome. Instead, I think the approach should focus on filling the gaps in the multi-dimensional semantics and adding whatever functionality vector.gather is missing today.

I think that’s a good idea. We shouldn’t leave vector.gather hanging. I’ll start by improving vector.gather and vector.scatter support, and then we can move on to discussing whether we need vector.transfer_gather / vector.transfer_scatter.

  • When bringing an operation to the multi-dimensional domain it’s expected that certain traits need to be encoded per dimension. We do that for in_bounds, for example, among many other cases. For the gather case, it shouldn’t be different! We should aim to encode whether any dimension is contiguous or not, not only the innermost one. Knowing that the outer dimensions of a memory access are contiguous, even when the innermost is not, can enable some optimizations. This information could be encoded per dimension by adding new attributes (e.g., [contiguous, random, strided]) to vector.gather.

Good idea, I’ll start by adding something similar to vector.gather / vector.scatter.
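As a strawman for what that per-dimension information could look like (the access_kind attribute below is hypothetical, not existing syntax; the rest follows the current vector.gather form for a 2-D base):

```mlir
// Hypothetical per-dim attribute: dim 0 is randomly indexed, dim 1 contiguous.
%g = vector.gather %base[%c0, %c0][%offsets], %mask, %pass_thru
       {access_kind = ["random", "contiguous"]}
     : memref<8192x8xf16>, vector<4xindex>, vector<4xi1>, vector<4xf16>
       into vector<4xf16>
```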

I hope this makes sense! I’m open to discussing this further in a higher bandwidth venue (call?, Euro LLVM?, …) or whatever works for you!

Let’s have an open call / meet at EuroLLVM after I’ve improved vector.gather / vector.scatter support for multi-dimensional indices. I doubt there is any opposition to that, and it’s an overall improvement that we can build on.

I would leave any kind of indexing/permutation maps out of the picture to simplify an already highly complex operation. I wouldn’t scope this within the transfer op family either. Transfer ops have been great at abstracting away all the details of memory loads and stores for cases where we really don’t care much about the memory access pattern or how the data is loaded. However, they have been a struggle when we need to reason about that information, because they concentrate too much information. For the gather case, we are looking at encoding specifics of the memory access itself, which doesn’t align with the aforementioned main goals of transfer ops.

I have a different opinion on this, but let’s discuss this later when we have the call about vector.transfer_gather.

Thanks for the reply, Diego. These are really useful technical points; they make sense, and I can make progress on them.

I have started sending PRs to improve multi-dimensional index support for vector.gather and vector.scatter.

The first set of patches makes vector.gather and vector.scatter consistent with each other:

  1. [mlir][vector] Decouple unrolling gather and gather to llvm lowering by Groverkss · Pull Request #132206 · llvm/llvm-project · GitHub
  2. [mlir][vector] Allow multi dim vectors in vector.scatter by Groverkss · Pull Request #132217 · llvm/llvm-project · GitHub
  3. [mlir][vector] Allow lowering multi-dim scatters to LLVM by Groverkss · Pull Request #132227 · llvm/llvm-project · GitHub

Note that each patch depends on the previous one, and later patches will continue to be stacked in the same way.
