[RFC] Adding Gather, Scatter Ops

At this time, MLIR is still missing higher-order gather and scatter abstractions. This severely limits the ability to represent and reason about such primitives.

This RFC proposes to add such abstractions to the tensor dialect, driven by first-principles considerations and designed to compose with other abstractions for tensors and buffers. These operations are the second embodiment of a notional SubsetOpInterface (the first being tensor.extract_slice and tensor.insert_slice).

This is related to the larger discussion on dense and sparse structured codegen.

By focusing on first principles, clean semantics and composition with existing abstractions, we want to avoid a slippery slope of conflation and complexity that would make the op extremely hard to use in practice (e.g. the HLO gather, scatter and select_and_scatter ops; the related Gather HLO is Complex doc also contains relevant background information).

GatherOp

The gather operation extracts a subset of the elements from an input tensor at the given indices. In its fine-grained form, the tensor of indices specifies all the coordinates of every element to extract (i.e. the COO format *minus* the payload, a.k.a. a list of coordinates).

// For each 1x2 triple of coordinates in %indices, extract the element (i.e. 0-D
// subset) at the coordinates triple in %input.
// This corresponds to the COO format *minus* the payload (a.k.a list of coordinates), 
// enumerating triples.
//
// Note: In the IR below, the indexing part of the indices tensor is isolated by a 
// whitespace (e.g. we write tensor<1x2x 3xindex> instead of tensor<1x2x3xindex> to 
// better isolate the coordinate part "3xindex").
//
// This corresponds to an implicit coordinates = [0, 1, 2] (all dims gathered).
%out = tensor.gather %input[%indices] : 
  tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<1x2xf32>

A slice variant is provided that allows specifying whole slices of the input tensor, to allow rectangular subsets and avoid obvious performance pitfalls where possible:

// For each 5x6 singleton of coordinates n in %indices, extract the slice
// %input[:, n:n+1, :] (i.e. a 2x1x4 subset) of %input.
%out = tensor.gather %input[%indices] coordinates = [1] : 
  tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x1x4xf32>
// Could also be spelled as:
//   tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x 2x1x4xf32>

Only full slices are supported; if one desires partial slices, one should compose with other tensor ops such as tensor.extract_slice.
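
For example, a partial 2-element window of each gathered slice could be obtained by composing with tensor.extract_slice (a sketch reusing the types of the example above):

%full = tensor.gather %input[%indices] coordinates = [1] : 
  tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x1x4xf32>
// Carve the trailing 2-element window out of every gathered 2x1x4 slice.
%window = tensor.extract_slice %full[0, 0, 0, 0, 2] [5, 6, 2, 1, 2] [1, 1, 1, 1, 1] : 
  tensor<5x6x2x1x4xf32> to tensor<5x6x2x1x2xf32>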

An optional unique unit attribute may be specified to assert that the coordinates are known to be unique at runtime. Setting the unique attribute when the coordinates are not truly unique is undefined behavior.
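
For illustration, a possible spelling on the fine-grained form (the exact attribute placement is a syntax detail):

%out = tensor.gather %input[%indices] unique : 
  tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<1x2xf32>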

ScatterOp

The scatter operation is the symmetrical counterpart of gather. It inserts a subset of elements from a source tensor into a destination tensor at the given indices. In its fine-grained form, the tensor of indices specifies all the coordinates at which every element is inserted (i.e. the COO format *minus* the payload, a.k.a. a list of coordinates).

// For each 1x2 triple of coordinates in %indices, insert the 
// element (i.e. 0-D subset) at the coordinates triple in %dest.
// This corresponds to an implicit coordinates = [0, 1, 2].
//
%out = tensor.scatter %input into %dest[%indices] : 
  tensor<1x2xf32> into tensor<4x4x4xf32>[tensor<1x2x 3xindex>] 
    -> tensor<4x4x4xf32>

A slice variant is provided that allows inserting whole slices into the destination tensor, to allow rectangular subsets and avoid obvious performance pitfalls where possible:

// For each 5x6 singleton of coordinates n in %indices, insert the corresponding
// 4x4 slice of %input at position [:, n:n+1, :] in %dest.
//
%out = tensor.scatter %input into %dest[%indices] coordinates = [1] : 
  tensor<5x6x4x4xf32> into tensor<4x4x4xf32>[tensor<5x6x 1xindex>] 
    -> tensor<4x4x4xf32>

Only full slices are supported; if one desires partial slices, one should compose with other tensor ops such as tensor.insert_slice.
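
For example, scattering a partial 2-element window could be spelled by composing the ops above (a sketch, using the dropped-unit-dim spelling of the preceding example; %window holds the new values):

// Gather the current full slices at the target positions.
%full = tensor.gather %dest[%indices] coordinates = [1] : 
  tensor<4x4x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x4x4xf32>
// Overwrite the trailing 2-element window within every full slice.
%upd = tensor.insert_slice %window into %full[0, 0, 0, 2] [5, 6, 4, 2] [1, 1, 1, 1] : 
  tensor<5x6x4x2xf32> into tensor<5x6x4x4xf32>
// Scatter the full slices back.
%res = tensor.scatter %upd into %dest[%indices] coordinates = [1] : 
  tensor<5x6x4x4xf32> into tensor<4x4x4xf32>[tensor<5x6x 1xindex>] 
    -> tensor<4x4x4xf32>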

An optional unique unit attribute may be specified to assert that the coordinates are known to be unique at runtime. Setting the unique attribute when the coordinates are not truly unique is undefined behavior.

No “reduction combinator” is attached to the scatter op; “scatter update” semantics are spelled as:

// %dest and %indices represent where the scatter should occur.
%0 = some_value_to_scatter_into_a_dest(...)
// Gather the current values at the same indices.
%1 = tensor.gather %dest[%indices]
%2 = update_compute(%0, %1)
%3 = tensor.scatter %2 into %dest[%indices]
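
For concreteness, the same pattern with the types of the fine-grained examples above, using addition as a stand-in for update_compute (%update holds the values to combine in):

%old = tensor.gather %dest[%indices] : 
  tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<1x2xf32>
%new = arith.addf %update, %old : tensor<1x2xf32>
%res = tensor.scatter %new into %dest[%indices] : 
  tensor<1x2xf32> into tensor<4x4x4xf32>[tensor<1x2x 3xindex>] 
    -> tensor<4x4x4xf32>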

Bufferization

In the first version, the n-D gather and scatter operations must lower to an abstraction that performs copies. This is because the buffer type system is currently not rich enough to allow multiple non-contiguous views in the same type (this is also related to the following discussion and jagged buffer type section).

The issue is more clearly visible in a notional buffer version of the op:

// memref<?x4xf32> is a contiguous buffer of ?x4 elements, gather from
// random input slices must copy to the contiguous output.
%out = gather %input[%indices] coordinates = [1] : 
  memref<4x4xf32>[memref<?x 1xindex>] -> memref<?x4xf32>

// Nested buffer support would allow gather to view into the input data.
%out = gather %input[%indices] coordinates = [1] : 
  memref<4x4xf32>[memref<?x 1xindex>] -> memref<? x memref<4xf32>>

This copy requirement can be side-stepped by gradually introducing loops, as discussed in ample detail elsewhere (reshape case and gather case). This allows building up the abstractions progressively, without frontloading the need for a new buffer type, while still allowing efficient lowerings.
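
For illustration, a sketch of such a loop-based lowering for the first buffer example above (%out is assumed pre-allocated as memref<?x4xf32>; element-by-element copies for simplicity):

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%n = memref.dim %indices, %c0 : memref<?x1xindex>
scf.for %i = %c0 to %n step %c1 {
  // Load the coordinate along the gathered dimension (coordinates = [1]).
  %idx = memref.load %indices[%i, %c0] : memref<?x1xindex>
  scf.for %j = %c0 to %c4 step %c1 {
    %v = memref.load %input[%j, %idx] : memref<4x4xf32>
    memref.store %v, %out[%i, %j] : memref<?x4xf32>
  }
}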

Future Work (Outside of The Current Proposal)

Parallel extensions.

My first instinct is that this seems like a bridge too far that could lead us to repeat the mistakes of XLA gather / scatter / select_and_scatter.

I would prefer to think that we can have composable foldings on pad + gather to catch those cases.

Do you want to propose and work through some IR examples and see where failures to compose could arise?

How would the op combine two values that map to the same output index?

Is the [1] in this example referring to the dimension that the indices correspond to? Is that a parameter too?

By full slices, are you referring to slicing only along one dimension? IIUC, the way to get what tf.gather_nd does with slicing into the first N dimensions is by slicing the result of this, right?

One thing to note is that the proposal is for notional “structured indirect load” and “structured indirect store” operations. These are semantically meaningful building blocks that are missing today.
This is conceptually much more in line with “bottom-up” thinking and closer to vector.gather / vector.scatter.
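
For reference, here is what the existing 1-D op looks like (spelling as in the vector dialect documentation):

// Gather 16 elements of %base at offsets %index_vec (under %mask).
%g = vector.gather %base[%c0][%index_vec], %mask, %pass_thru : 
  memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>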

The vector op is 1-D only and has lower-level semantics; the doc says:

The semantics of the operation closely correspond to those of the `llvm.masked.scatter` intrinsic.

The notion of attaching a compute to the indirect store seems like a separable “top-down” concern.
I realize that the ML community has historically conflated the two and we need a reasonable lowering path for that.

I think the exercise is about defining n-D tensor ops that raise the level of abstraction and lower to the vector abstraction that we know connects properly to HW ISA.

ML concerns should not leak into it and should be kept separate and higher-level. In this context, it may make sense to adopt the repeated scatter indices behavior from the vector op:

The vector dialect leaves out-of-bounds and repeated index behavior
undefined. Underlying implementations may enforce strict sequential
semantics for the latter, though.

Spelling out a little more what a mixed ml_framework.scatter + update_compute could look like when lowered to tensor.gather, tensor.scatter:

%3 = some_ml_framework_scatter_with_update @update %2, %dest[%indices]: 
  tensor<?x6xf32> into tensor<4x?xf32>[tensor<?x6x 2xindex>]

Could lower to:

// %dest: tensor<4x?xf32> and %indices: tensor<?x6x 2xindex> represent the 
// tensor and locations into which the scatter should occur.
%0 = some_value(...): tensor<?x6xf32>

// If `%indices: tensor<?x6x 2xindex>` has repeated indices then 
// `%1: tensor<?x6xf32>` will have multiple copies of the same value.
%1 = tensor.gather %dest[%indices]: 
  tensor<4x?xf32>[tensor<?x6x 2xindex>] -> tensor<?x6xf32>

// This is a "more parallel" version without conflicts.
%2 = update(%0: tensor<?x6xf32>, %1: tensor<?x6xf32>) -> tensor<?x6xf32> 

%3, %unique_indices = histogram_compute(%2: tensor<?x6xf32>, 
                                        %indices: tensor<?x6x 2xindex>) 
  -> (tensor<?x6xf32>, tensor<?x6x 2xindex>)

// We now have unique indices here and can use the unit attr `unique` for the op.
%4 = tensor.scatter %3 into %dest[%unique_indices] unique : 
  tensor<?x6xf32> into tensor<4x?xf32>[tensor<?x6x 2xindex>] 
    -> tensor<4x?xf32>

Following this line of thought, a first-class histogram op appears that would “bin” together elements whose indices are equal. This is a strawman, but it is one direction this line of thought points towards.
A first-class histogram op is a concept I’ve heard others mention in the past.

Yes, coordinates = [1] means “gather along” dimension 1 (i.e. take your indices from dimension 1 and keep whole slices along the rest). This is a DenseI64ArrayAttr; the invariants are as follows (see the sketch after this list):

  • the last dim of the indices tensor must match the length of the array attr (i.e. coordinates = [0, 1] <=> tensor<...x 2xindex>).
  • the dims that are not listed in coordinates carry over from the input to the output type (i.e. coordinates that are listed are “dropped”). So with coordinates = [1] we get 2x3x4 → 2x4
  • the leading dims of the output tensor are the leading dims of the indices tensor (i.e. 5x6).
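
Putting the three invariants together on the running example (a sketch using the dropped-unit-dims spelling):

// coordinates = [1] lists one dim, so the indices tensor ends in "1xindex" (invariant 1).
// Input dim 1 is dropped; dims 0 and 2 (sizes 2 and 4) carry over (invariant 2).
// The leading 5x6 dims of %indices become the leading dims of the result (invariant 3).
%out = tensor.gather %input[%indices] coordinates = [1] : 
  tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x4xf32>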

I am not too sure about the tf semantics… by full slices I just mean that you cannot take a window out of the tensor. I.e. given my previous point about:

the dims that are not listed in `coordinates` carry over from the input to 
the output type (i.e. coordinates that are listed are "dropped"). 
So with `coordinates = [1]` we get `2x3x4` -> `2x4`

You could not express taking a window and get a non-full 2x2 slice; you’d have to use tensor.extract_slice for that.

I get “Access Denied” when trying to access the “Gather HLO is complex” document.

Can some part of that document be made public in order to get on the same page regarding what the codegen issues are?

You’re right, sorry, let me ask around … OTOH this was written by @sanjoyd and commented on by @raghavanr so maybe they can give a good TL;DR here if this is fresh enough in their memory?

Shouldn’t the index dims replace the coordinates that are listed? For example, in the slice gather above with coordinates = [1]:

Shouldn’t the output be tensor<2x5x6x4xf32> instead? Is there any reason for always keeping the index dims as the outer dims of the input?

Thanks for clarifying that. So, IIUC, any dim that is not specified in the “coordinates” will be gathered in full.

Are these gather ops supposed to handle slicing into N dimensions (instead of 1) as done in tf.gather_nd? I’m just trying to understand how that can be done using these ops.

IIRC, the complexity with Gather HLO is that it tries to provide all possible ways to index into the input as well as all possible ways to format the output. In other words, it tries to eliminate the need for reshapes on the input and output and provides a way to specify all of those in the op itself.

The op must define some semantics for this. I’ve seen the reference to vector.scatter, but it is unclear if similar semantics are proposed here. If they are, i.e., if scatter with repeated indices is UB, then there is no need for the unique attribute because repeated indices would be UB anyway. We can still have the keyword in the syntax to make it obvious when reading the IR.

Could you please provide more details on the supported indexing schemes? Is the last dimension of the index tensor always expected to contain the coordinates, or can it be any other dimension? Must the coordinates be stored as index? Is slicing allowed along more than one dimension? If so, how are the slice positions supposed to be stored in the index tensor?

Are unranked tensors supported? (e.g., gathering from an unranked tensor into a ranked one by having a dynamic size for the last dimension).

A couple of syntactic nitpicks:

  • coordinates = [1] feels like a particularly confusing choice of name. It can be something like slice_dims(1) that clearly indicates that a slice is being taken;
  • there is little value in innovating with trailing type syntax (tensor[tensor] -> tensor) instead of just using a function type ((tensor, tensor) -> tensor).

Just as an FYI, we are experimenting with simplifying scatter and gather to reduce some of the complexity. See also the simplification pass that can be used to evaluate performance implications of this.


Thanks for the comments and questions!

Restructuring the order of my reply a bit for better flow.

Yes, it wasn’t clear until spelling out the “multiple indices in the same bin” case and seeing a histogram-like construct pop up. Favoring alignment between tensor and vector semantics follows established bottom-up principles (i.e. conceptually considering tensor as an n-D form with ? that will undergo bufferization). The first-principles implication of this line of thought is that the unique keyword is required for scatter and optional for gather. This creates an asymmetry (something I generally watch out for as an IR design principle), but a much milder one than a magical region that performs a histogram-reduction.

I think biasing towards connecting to the vector → LLVM → HW ISA path is unambiguously different from “what is good for algebraic rewrites and ML frameworks” and is what I am trying to capture as a key modular building block.

Yes, as spelled with tensors of indices, the rightmost dimension contains the n-D coordinates, i.e. the op specification uses an SoA notation. This is a convention that is expected to compose with layouts (only on buffer types for now) if we want to materialize as e.g. AoS in memory or in vector registers after vectorization. The rationale is to keep the op simple, and this convention seems the most intuitive. It has also worked well in the past (e.g. in how to spell convolutions): we have found it better to keep the op specification unsurprisingly simple and deal with complexity by folding “pairs of ops”. IOW, if one needs to support 3 alternatives, making the op semantics support these 3 alternatives leaks into corner cases everywhere; this leakage is the first-order error I am trying to minimize (in O-notation).
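
Concretely, for %indices : tensor<5x6x 2xindex>, the coordinate pair gathered for output position [%a, %b, ...] is read off the trailing dimension (a sketch; %c0 and %c1 are index constants):

%i = tensor.extract %indices[%a, %b, %c0] : tensor<5x6x2xindex>
%j = tensor.extract %indices[%a, %b, %c1] : tensor<5x6x2xindex>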

No strong opinion here; the thinking is that index → i12 types of compression should happen in a more target-specific way, later in the pipeline. But if we have evidence that this is not possible, it should be a relatively simple extension compared to the more difficult parts of the design (agreeing on whether histogram is a building block derived from first principles).

Yes, trying to re-spell while addressing your remark about a confusing choice of name :slight_smile: .

%out = tensor.gather %input[%indices] gather_along_dimensions(0, 2) : 
  (tensor<2x3x4xf32>, tensor<5x6x 2xindex>) -> tensor<5x6x3xf32>
// the result type could also be spelled `tensor<5x6x 1x3x1xf32>`;
// maybe dropping the 1's too quickly is confusing?

The proposed op convention is SoA for the indices, as discussed above; materialization into storage (e.g. memory or multiple vector registers) is orthogonal and will compose.

Generally strong “not yet” from me for rank-agnostic codegen at this point.

There is a world of unranked → while + linearize / delinearize abstractions out there (this is how a lot of code in Torch7 was implemented, FWIW) but I think we are nowhere near ready to start these discussions.

There have also been ample discussions that this is separable at the frontend, even today.

Thanks, I went for gather_along_dimensions to be super explicit in this illustration but gather_dims, scatter_dims work for me too. Other suggestions welcome!

As much as I’d love to accommodate, I am afraid this one is trickier… I think I created the confusion by only giving a 1-D example here and dropping the 1s too quickly. If we take the example from this reply and just look at:

gather_dims(0, 2) : 
  (tensor<2x3x4xf32>, tensor<5x6x 2xindex>) -> tensor<5x6x1x3x1xf32>
// result type could also be tensor<5x6x3xf32> by dropping 1s.

then I am not sure how to spell a different, yet unsurprising, convention.
If we tried hard we could add another attribute to specify a permutation, but that conflicts with the goal of not mixing in things that should compose with this abstraction (e.g. tensor.reshape, tensor.transpose, etc.).
Do you see a way to achieve what you want without stepping into lava?

As Alex mentioned, my choice of notation is confusing, I hope the examples are clearer now?

Re. tf.gather_nd I don’t know (and I’d prefer not to look too deeply unless I really must :slight_smile: ).
I think this would handle N-D slices as I’d hope we want to represent them but please correct me if I am wrong.

To be clear, I certainly don’t want to be flippant about TF, XLA, ONNX, Torch and other systems that have years of experience looking at those abstractions but I believe a “common lingua franca” is super important for us to exchange in and that this lingua franca is “(pseudo) MLIR + comments”.

I’d certainly appreciate an effort to distill some of these TF abstractions to our common lingua franca and get educated with those!


Thanks! This clarified most of my questions, please make sure to add this to the op documentation.

This feels like mostly a question of using TensorOf<Index> vs TensorOf<IntegerLike> in the op definition.

No objection, just make sure to scope it out in the op documentation.

We could have something like

gather_dims(0, 2)  index_dim(1) : 
  (tensor<2x3x4xf32>, tensor<5x 2 x6xindex>) -> tensor<5x3x6xf32>

where an extra attribute indicates which of the index tensor dimensions contains the coordinates, instead of it always being the last one by convention. This would, however, keep all non-gathered dimensions adjacent. To split them and achieve the form @raghavanr expected in the 1-D case, we’d need a sort of inverse: indicate where the non-gathered dimensions go via some attribute

gather_dims(0, 2) output_shape[index(0), source(1), index(2)] : 
  (tensor<2x3x4xf32>, tensor<5x6x 2xindex>) -> tensor<5x3x6xf32>

At this point it feels like it duplicates reshape but with less expressive power, so I wouldn’t go there as long as it doesn’t harm further bufferization.

I believe the sort of behavior that @raghavanr is looking to emulate is driven by some ML ops where only 1 dimension is gathered along and it is injected in the non-ambiguous single place where it makes sense.

Compared to the op proposed here I believe this is at the same time less and more powerful.

The ML ops behavior can be viewed as less powerful because the proposed tensor.gather can take n-D slices along non-contiguous dimensions. While this may be more powerful than what some ML ops can support, it composes well with the existing tensor.extract_slice/insert_slice and memref.subview, as well as memref.transpose. In fact, restricting tensor.gather to a more limited 1-D form risks creating many corner cases when composing with other abstractions. In other words, the more generic n-D slices already compose well, and the proposed design is informed by this experience.

The ML ops behavior can be viewed as more powerful because ML ops also embed a permutation to inject the gathered dimension at the desired place. Similarly, based on our experience with memref.subview, memref.transpose and strided memrefs, I am confident that these things compose properly with a “metadata transpose” that can be materialized by different layout choices.
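
For illustration, the tensor<2x5x6x4xf32> form asked about earlier could be recovered by composing with such a transpose (a sketch; the tensor.empty and linalg.transpose spellings are from upstream MLIR and assumed here):

%g = tensor.gather %input[%indices] gather_dims(1) : 
  (tensor<2x3x4xf32>, tensor<5x6x 1xindex>) -> tensor<5x6x2x4xf32>
%init = tensor.empty() : tensor<2x5x6x4xf32>
// Move the gathered leading dims 5x6 into the middle: (5, 6, 2, 4) -> (2, 5, 6, 4).
%t = linalg.transpose ins(%g : tensor<5x6x2x4xf32>)
                      outs(%init : tensor<2x5x6x4xf32>)
                      permutation = [2, 0, 1, 3]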

Just wanted to give a perspective here from Torch-MLIR. We need efficient support for the reduction-on-repeated-indices case. It’s not obvious to me how the gather + update + histogram + scatter sequence being proposed here lowers to optimal code on GPUs, which is likely to “just be a regular scatter, but using atomic read-modify-write ops instead of stores”. Maybe the GPU code path forks off while we still have the “scatter with region” abstraction level?
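
For concreteness, on buffers that strategy might look something like this (a sketch, assuming 1-D indices and an additive update):

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%n = memref.dim %indices, %c0 : memref<?xindex>
scf.parallel (%i) = (%c0) to (%n) step (%c1) {
  %idx = memref.load %indices[%i] : memref<?xindex>
  %v = memref.load %updates[%i] : memref<?xf32>
  // The atomic read-modify-write resolves repeated-index conflicts.
  %old = memref.atomic_rmw addf %v, %dest[%idx] : (f32, memref<?xf32>) -> f32
}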

How is the lowering done currently for the repeated indexes case? I’m curious because we would like to figure this out for tcp.scatter as well.

If repeated_indices is true, then the scattering loop is serialized. I don’t think that will do a great job on GPUs though, unless the batch dimensions are so large that we get enough parallelism from them. I do think that tcp.scatter could be the bridge we need here.


Thanks for your perspective @_sean_silva.

In this RFC, I left out the parallel part of the equation, and I think there are two avenues along which rewrites could bottom out into the strategy you mention.

The first avenue is related to fusion through gather and scatter. As we create loops and take slices, the ops become lower-dimensional and “more expressive” (see the section on expressiveness vs transformation power). I can certainly see slices of (histogram + scatter) requiring rewrites to RMW operations in order to allow fusion/tiling into parallel loops. It is possible that “a scatter with a histogram region” is the one true better abstraction to target such a rewrite, but I don’t think we are there yet. A first step towards that, IMO, is getting the design right by getting tiling + in-place bufferization to work as well in the gather/scatter case as in the dense case.

The second avenue is that I see similarities between gather/scatter/parallel_scatter and extract_slice/insert_slice/parallel_insert_slice that could generalize in a SubsetOpInterface. This similarity relates to the discussion on abstractions for parallelism and tensors and how to represent the reduction part. The reduction aspect is not solved yet (our friends on the XLA side have something similar with a reduction region that we may want to adopt) and I can see a parallel_scatter with a region that contains a histogram_compute-like op as a way to represent some of this without losing information until after bufferization.
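
A notional spelling (all op names and the region syntax are hypothetical, with a simple additive combiner standing in for the more general histogram-like reduction):

%r = parallel_scatter %updates into %dest[%indices] combiner {
^bb0(%new: f32, %old: f32):
  %s = arith.addf %new, %old : f32
  yield %s : f32
} : tensor<?x6xf32> into tensor<4x?xf32>[tensor<?x6x 2xindex>] -> tensor<4x?xf32>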

In any case, I think both these avenues are complementary to our ability to represent n-D gather / scatter that compose well with the lower-level abstractions that we already have. My take is we can reevaluate adding a region to the sequential n-D scatter op once the semantics, tiling/fusion and lowerings work well.