At this time, MLIR is still missing higher-order gather and scatter abstractions. This severely limits the ability to represent and reason about such primitives.
This RFC proposes to add such abstractions to the tensor dialect, driven by first-principle considerations, and designed to compose with other abstractions for tensors and buffers. These operations are the second embodiment of a notional SubsetOpInterface (the first embodiments are tensor.extract_slice and tensor.insert_slice).
This is related to the larger discussion on dense and sparse structured codegen.
By focusing on first principles, clean semantics, and composition with existing abstractions, we want to avoid a slippery slope of conflation and complexity that would make the ops extremely hard to use in practice (e.g. the HLO gather, scatter, and select_and_scatter ops; the related "Gather HLO is Complex" doc also contains relevant background information).
GatherOp
The gather operation extracts a subset of the elements from an input tensor at the given indices. In its fine-grain form, the tensor of indices specifies all the coordinates of every element to extract (i.e. the COO format *minus* the payload, a.k.a. a list of coordinates).
// For each 1x2 triple of coordinates in %indices, extract the element (i.e. 0-D
// subset) at the coordinates triple in %input.
// This corresponds to the COO format *minus* the payload (a.k.a list of coordinates),
// enumerating triples.
//
// Note: In the IR below, the indexing part of the indices tensor is isolated by a
// whitespace (e.g. we write tensor<1x2x 3xindex> instead of tensor<1x2x3xindex> to
// better isolate the coordinate part "3xindex").
//
// This corresponds to an implicit gather_along_dimension=(0, 1, 2)
%out = tensor.gather %input[%indices] :
tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<1x2xf32>
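To make the fine-grain semantics concrete, here is a small NumPy model (an illustrative sketch, not part of the proposal; `gather_elements` is a hypothetical helper name):

```python
import numpy as np

# Model of the fine-grain tensor.gather above: %input is 4x4x4, %indices is
# 1x2x3 (the trailing "3" holds one (i, j, k) coordinate triple per entry),
# and the result is 1x2 (one extracted f32 element per triple).
def gather_elements(input, indices):
    batch_shape = indices.shape[:-1]           # leading dims enumerate the triples
    out = np.empty(batch_shape, dtype=input.dtype)
    for idx in np.ndindex(*batch_shape):
        out[idx] = input[tuple(indices[idx])]  # 0-D subset at the coordinate triple
    return out

input = np.arange(4 * 4 * 4, dtype=np.float32).reshape(4, 4, 4)
indices = np.array([[[0, 0, 0], [3, 2, 1]]])   # shape 1x2x3
out = gather_elements(input, indices)          # shape (1, 2)
```

The result drops the trailing coordinate dimension of %indices and keeps its leading (enumeration) dimensions, matching the tensor<1x2xf32> result type above.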
A slice variant is provided that allows specifying whole slices of the input tensor, to support rectangular subsets and avoid obvious performance pitfalls where possible:
// For each 5x6 singleton of coordinates in %indices, extract the 2-D slice [*, n:n+1, *]
// at the coordinates singleton in %input.
%out = tensor.gather %input[%indices] gather_along_dimension(1) :
tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x1x4xf32>
// Could also be spelled with the gathered unit dimension rank-reduced:
// tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x4xf32>
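The slice variant can be modeled the same way; the NumPy sketch below hard-codes gather_along_dimension(1) and is purely illustrative (`gather_along_dim1` is a hypothetical helper name):

```python
import numpy as np

# Model of the slice variant above: gather_along_dimension(1) on a 2x3x4
# input; each index singleton n selects the full slice input[:, n:n+1, :]
# (shape 2x1x4), so a 5x6x1 indices tensor yields a 5x6x2x1x4 result.
def gather_along_dim1(input, indices):
    batch_shape = indices.shape[:-1]
    out = np.empty(batch_shape + (input.shape[0], 1, input.shape[2]),
                   dtype=input.dtype)
    for idx in np.ndindex(*batch_shape):
        n = indices[idx][0]
        out[idx] = input[:, n:n + 1, :]        # whole slice, unit dim kept
    return out

input = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
indices = np.zeros((5, 6, 1), dtype=np.int64)  # all singletons select n = 0
out = gather_along_dim1(input, indices)        # shape (5, 6, 2, 1, 4)
```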
Only full slices are supported; if one desires partial slices, one should compose with other tensor ops such as tensor.extract_slice.
An optional unique unit attribute may be specified to indicate that the coordinates are statically guaranteed to be unique at runtime. Incorrectly setting the unique attribute when the coordinates are not truly unique is undefined behavior.
ScatterOp
The scatter operation is the symmetrical counterpart to gather. It inserts a subset of the elements from an input tensor into a destination tensor at the given indices. In its fine-grain form, the tensor of indices specifies all the coordinates of every element to insert (i.e. the COO format *minus* the payload, a.k.a. a list of coordinates).
// For each 1x2 triple of coordinates in %indices, insert the
// element (i.e. 0-D subset) at the coordinates triple in %dest.
// This corresponds to implicit coordinates = [0, 1, 2]
//
%out = tensor.scatter %input into %dest[%indices] :
tensor<1x2xf32> into tensor<4x4x4xf32>[tensor<1x2x 3xindex>]
-> tensor<4x4x4xf32>
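A matching NumPy model of the fine-grain scatter (illustrative only; `scatter_elements` is a hypothetical helper name, and the sketch assumes unique coordinates):

```python
import numpy as np

# Model of the fine-grain scatter above: each entry of the 1x2 %input is
# written into %dest at the matching coordinate triple from %indices.
# Value semantics: the destination is copied, then updated.
def scatter_elements(input, dest, indices):
    out = dest.copy()
    batch_shape = indices.shape[:-1]
    for idx in np.ndindex(*batch_shape):
        out[tuple(indices[idx])] = input[idx]
    return out

dest = np.zeros((4, 4, 4), dtype=np.float32)
input = np.array([[1.0, 2.0]], dtype=np.float32)  # shape 1x2
indices = np.array([[[0, 0, 0], [3, 2, 1]]])      # shape 1x2x3
out = scatter_elements(input, dest, indices)
```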
A slice variant is provided that allows inserting whole slices into the destination tensor, to support rectangular subsets and avoid obvious performance pitfalls where possible:
// For each 5x6 singleton of coordinates in %indices, insert the 2-D
// slice [:, 1, :] at the coordinates singleton into %dest.
//
%out = tensor.scatter %input into %dest[%indices] coordinates = [1] :
tensor<5x6x4x4xf32> into tensor<4x4x4xf32>[tensor<5x6x 1xindex>]
-> tensor<4x4x4xf32>
Only full slices are supported; if one desires partial slices, one should compose with other tensor ops such as tensor.insert_slice.
An optional unique unit attribute may be specified to indicate that the coordinates are statically guaranteed to be unique at runtime. Incorrectly setting the unique attribute when the coordinates are not truly unique is undefined behavior.
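A two-line NumPy experiment illustrates the ambiguity that the unique attribute rules out: with duplicate coordinates, the stored value depends on write ordering (here, NumPy's last-write-wins behavior for fancy-index assignment):

```python
import numpy as np

# Two writes target the same coordinate; the result depends on enumeration
# order -- exactly the nondeterminism the `unique` attribute forbids.
dest = np.zeros(4, dtype=np.float32)
idx = np.array([1, 1])              # duplicate coordinate
dest[idx] = np.array([10.0, 20.0])  # both writes land on dest[1]
```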
No “reduction combinator” is attached to the scatter op; “scatter update” semantics are spelled as:
// %dest and %indices represent where the scatter should occur
%0 = some_value_to_scatter_into_a_dest(...)
%1 = gather_at_the_same_indices %dest[%indices]
%2 = update_compute(%0, %1)
%3 = scatter %2, %dest[%indices]
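The composition above can be traced with a NumPy sketch, using elementwise addition as the update_compute (the choice of combinator is illustrative; the op itself attaches none):

```python
import numpy as np

# "Scatter update" spelled as gather + compute + scatter, on a 1-D dest.
dest = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
indices = np.array([0, 2])                      # unique coordinates

v0 = np.array([10.0, 20.0], dtype=np.float32)   # some_value_to_scatter_into_a_dest
v1 = dest[indices]                              # gather_at_the_same_indices
v2 = v0 + v1                                    # update_compute
out = dest.copy()
out[indices] = v2                               # scatter at the same indices
```

Because the coordinates are unique, the gather/compute/scatter sequence is unambiguous and the combinator stays outside the op.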
Bufferization
In the first version, the n-D gather and scatter operations must lower to an abstraction that performs copies. This is because the buffer type system is currently not rich enough to allow multiple non-contiguous views in the same type (this is also related to the following discussion and the jagged buffer type section).
The issue is more clearly visible in a notional buffer version of the op:
// memref<?x4xf32> is a contiguous buffer of ?x4 elements, gather from
// random input slices must copy to the contiguous output.
%out = gather %input[%indices] coordinates = [1] :
memref<4x4xf32>[memref<?x 1xindex>] -> memref<?x4xf32>
// Nested buffer support would allow gather to view into the input data.
%out = gather %input[%indices] coordinates = [1] :
memref<4x4xf32>[memref<?x 1xindex>] -> memref<? x memref<4xf32>>
This copy requirement can be side-stepped by gradually introducing loops, as discussed in ample detail elsewhere (reshape case and gather case). This will allow building up the abstractions progressively, without frontloading the need for a new buffer type while still allowing efficient lowerings.
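The copy-based lowering can be sketched as an explicit loop over the indices, modeling the memref example with coordinates = [1] (i.e. each index n selects the slice [:, n]); `lower_gather_to_loop` is a hypothetical name:

```python
import numpy as np

# Sketch of the copy-based lowering: since the memref<?x4xf32> result must
# be contiguous, a loop (the scf.for equivalent) copies each gathered slice
# into a freshly allocated contiguous buffer.
def lower_gather_to_loop(input, indices):
    n = indices.shape[0]
    out = np.empty((n, input.shape[0]), dtype=input.dtype)  # contiguous alloc
    for i in range(n):
        out[i, :] = input[:, indices[i, 0]]  # copy one slice [:, n] out
    return out

input = np.arange(16, dtype=np.float32).reshape(4, 4)
indices = np.array([[2], [0], [2]])          # random, possibly repeated slices
out = lower_gather_to_loop(input, indices)   # shape (3, 4), fully materialized
```

With a nested buffer type, the loop body could instead record a view of input[:, n]; without it, the copy is unavoidable.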
Future Work (Outside of The Current Proposal)
Parallel extensions.