Authors: @nicolasvasilache (Google), @aartbik (Google)
Hello everyone,
Aart and I would like to share the results of an internal brainstorming on ideas that have been lingering for some time based on our learnings from building parts of MLIR Codegen.
There are 2 parts to this RFC:
- a thorough background analysis of existing abstractions and their current limitations
- a set of abstractions we propose to overcome these limitations.
Not all the abstractions proposed here should be undertaken at the same time; they should be further sliced into separate RFCs. Still, it is important to present the overall common vision in one comprehensive post.
If people think a public Google doc is a more inclusive or efficient medium for structuring comments, please let us know.
We hope it makes for exciting reading!
Context
Structured and Retargetable Code Generation (tech report) is built around dense rectangular strided subsets (a.k.a strided memref). A summary discussion about its principles is available here.
The Sparse Compiler (ACM pub, tech report) shares the computation representation layer as a target of higher level dialects and provides proper composition through a custom sparse tensor type.
This RFC proposes a vision for a progressive path forward from dense to sparse to better cover the whole spectrum of dense → ragged → jagged → sparse types and only pay the performance price of what we use. We discuss the limitations of the current abstractions and propose a path towards the unification and better composition of these approaches (both expressiveness and transformations).
This RFC also proposes improved abstractions related to iteration and data storage that refine progressive lowering by providing a much better bridge over the semantic gap between dense / ragged / jagged / sparse tensors and their implementation.
It also comes in the context of renewed interest in making concrete progress on a tensor compute primitives dialect (RFC, ODM). We anticipate this discussion to also be relevant for this context and will help trigger the next round of unification and collaboration.
Background: N-D Compute Ops Over M-D Data: Current Limitations.
MLIR supports multiple levels of IR operating at different levels of granularity. Structured ops (e.g. matmul, contraction, elementwise, transpose, generic, etc.) are compute primitives at a level of abstraction similar to “library calls on structured data”.
High-level operation semantics encode the relationship between:
- the data types,
- the iteration domain / control-flow / indexing required to traverse the data operands and
- the computational payload of the op.
This relationship between compute and data provides a flexible level of IR on which bufferization, tiling, fusion, vectorization, lowering to loops and other structured transformations can occur.
Such ops support both tensor and buffer forms, resembling:
%3 = linalg.matmul ins(%1, %2 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%0 : tensor<42x64xf32>)
linalg.matmul ins(%1, %2 : memref<42x32xf32>, memref<32x64xf32>)
outs(%0 : memref<42x64xf32>)
While this form allows a significant class of transforms and dialects to compose properly, it is also limited, by design, to the semantics that the operand types support. This comprises:
- Strided tensors, buffers and the subset operations tensor.extract_slice, tensor.insert_slice and memref.subview (in the dense case).
- The sparse_tensor dialect, types and lowerings (in the sparse case).
In this section, we discuss current limitations starting from the dense case and gradually relax the abstractions towards the sparse case.
Background: Dense Case: collapse_shape and expand_shape.
Limitations of the structured dense approach appear when composing with operations such as memref.collapse_shape.
// Dimension collapse (i, j) -> I, k -> J
%1 = memref.collapse_shape %0 [[0, 1], [2]] :
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
The current documentation states:
The memref.collapse_shape op produces a new view with a smaller rank whose
sizes are a reassociation of the original view. The operation is limited to such
reassociations, where subsequent, contiguous dimensions are collapsed into a
single dimension. Such reassociations never require additional allocs or copies.
Collapsing non-contiguous dimensions is undefined behavior.
The key abstraction gap is that the result of taking a subset of a memref.collapse_shape can generally not be represented by a strided memref. As a consequence, collapse_shape and expand_shape often act as a boundary for transformations.
A legalization pass exists that percolates collapse_shape up and expand_shape down, thus increasing the scope of transformations to higher-dimensional ops. This is a tradeoff, and we are now looking at going beyond this limitation more generally.
We want to allow transformations to permeate such boundaries and better explore the expressiveness / transformation tradeoff.
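As a small illustration of this gap, consider the following sketch (the shapes and the stride_spec placeholder are assumed for illustration):
// %A is a non-contiguous 2-D view of a larger buffer: a 42x64 window whose rows
// are 128 elements apart.
%A = memref.subview %buf[0, 0][42, 64][1, 1]
    : memref<42x128xf32> to memref<42x64xf32, stride_spec>
// Collapsing %A into a memref<2688xf32> is undefined behavior today (its rows are
// not contiguous), and a 1-D subset of e.g. 5 logical elements that straddles a
// row boundary spans two disjoint address ranges: no single (offset, size, stride)
// triple, i.e. no 1-D strided memref, can describe it.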
Background: Dense Case: Transformations of Interest
Two key transformations that one wants to apply at this level of abstraction are tiling and fusion. Note that tiling, fusion and lowering to loops are intimately related:
- lowering to loops is obtained by “tiling by 1” along all dimensions,
- fusion is performed on tiles of data (i.e. tile and fuse). In the limit, classical loop fusion can be achieved by “tiling by 1” along all dimensions and fusing.
One use case that showed up in the past and that we aim to generalize is fusion through reshape. Let’s take the example described in fusion through reshape as a motivating example and transcribe it to MLIR:
func.func @matmul_bias(%0: tensor<42x32xf32>, %1: tensor<32x64xf32>, %2: tensor<42x64xf32>,
%5: tensor<2688xf32>, %6: tensor<2688xf32>)
-> tensor<2688xf32> {
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%4 = tensor.collapse_shape %3 [[0, 1]] :
tensor<42x64xf32> into tensor<2688xf32>
%7 = linalg.elementwise "sub" ins(%4, %5: tensor<2688xf32>, tensor<2688xf32>)
outs(%6: tensor<2688xf32>)
return %7: tensor<2688xf32>
}
Assume we tile the elementwise operation by some value %ts; we obtain IR resembling:
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%4 = tensor.collapse_shape %3 [[0, 1]] :
tensor<42x64xf32> into tensor<2688xf32>
%7 = scf.for %i = %c0 to %N step %ts iter_args(%iter = %6: tensor<2688xf32>)
-> tensor<2688xf32> {
%a = tensor.extract_slice %4[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
%b = tensor.extract_slice %5[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
%c = tensor.extract_slice %iter[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
%res = linalg.elementwise "sub" ins(%a, %b: tensor<?xf32>, tensor<?xf32>) outs(%c: tensor<?xf32>)
%iter_next = tensor.insert_slice %res into %iter[%i][%ts][1]: tensor<?xf32> into tensor<2688xf32>
scf.yield %iter_next: tensor<2688xf32>
}
return %7: tensor<2688xf32>
It is now unclear how to even write the fusion of %3 and %4 into the loop: there are no operations and, more importantly, there are no types to express taking a 1-D %ts-sized slice out of a 2-D tensor or memref.
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %N step %ts iter_args(%iter = %6: tensor<2688xf32>)
-> tensor<2688xf32> {
// TODO: How to even express a 1-D rectangular subset of size %ts out of a 2-D rectangular tensor?
%a = tensor.extract_slice %3[...]: tensor<?xf32> from tensor<42x64xf32>
%b = tensor.extract_slice %5[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
%c = tensor.extract_slice %iter[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
%res = linalg.elementwise "sub" ins(%a, %b: tensor<?xf32>, tensor<?xf32>) outs(%c: tensor<?xf32>)
%iter_next = tensor.insert_slice %res into %iter[%i][%ts][1]: tensor<?xf32> into tensor<2688xf32>
scf.yield %iter_next: tensor<2688xf32>
}
return %7: tensor<2688xf32>
In particular, suppose we take a prime tile size (e.g. %ts = 5). Depending on the value of %i, the slice could cross a row boundary and take a different number of elements from each row:
- When %i == 0, the extract_slice yields a slice with 5 elements from row 0 and 0 from row 1.
- When %i == 62, the extract_slice yields a slice with 2 elements from row 0 and 3 from row 1.
Additionally, after bufferization, %3 could be materialized as a non-contiguous strided buffer, in which case the 5 elements would not even be in contiguous memory locations in the first place. This is not representable as a single 1-D rectangular region.
There is currently no good support for representing such an op and a data type that “gathers” elements.
Background: Dense Case: Expressiveness vs Transformation Power: Tile By Size 1 to Project Out a Non-Contiguous Dimension.
One possible approach to solving the expressiveness gap is to realize that it disappears if we “project away” the non-contiguous dimensions of the data type. This is possible by restricting the transformation power (i.e. “tiling by size 1” on the relevant collapsed multi-dimension, a.k.a lowering to loops along the collapsed multi-dimension).
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %N step %c1 iter_args(%iter = %6: tensor<2688xf32>)
-> tensor<2688xf32> {
// TODO: Since we only tile by 1 along 2688 = 42x64, we only need a tensor<1x1xf32> type
// which is always "rectangular".
// linalg.matmul can now be fused here but it becomes a 1-D linalg.dot op.
%tmp_a = some_op %3[f(%i), g(%i)][1, 1][1, 1]: tensor<1x1xf32> from tensor<42x64xf32>
%a = tensor.collapse_shape %tmp_a [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
%b = tensor.extract_slice %5[%i][1][1]: tensor<1xf32> from tensor<2688xf32>
%c = tensor.extract_slice %iter[%i][1][1]: tensor<1xf32> from tensor<2688xf32>
%res = linalg.elementwise "sub" ins(%a, %b: tensor<1xf32>, tensor<1xf32>) outs(%c: tensor<1xf32>)
%iter_next = tensor.insert_slice %res into %iter[%i][1][1]: tensor<1xf32> into tensor<2688xf32>
scf.yield %iter_next: tensor<2688xf32>
}
return %7: tensor<2688xf32>
In other words, when navigating the expressiveness / transformation tradeoff, it is possible to restrict expressiveness to increase transformation power and vice-versa.
Additionally, we can introduce linearize / delinearize abstractions to help represent this in closed form, while:
- still tiling with arbitrary tile sizes
- using the tile size 1 for bridging the expressiveness gap
In a nutshell, this approach resembles:
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %N step %ts iter_args(%iter = %6: tensor<2688xf32>)
-> tensor<2688xf32> {
%x = scf.for %ii = %i to min(%N, %i + %ts) step %c1 iter_args(%iter2 = %iter: tensor<2688xf32>)
-> tensor<2688xf32> {
// Delinearize the linear index %ii in the 42x64 basis
// (note that this abstraction also supports dynamic SSA basis values).
%multi_index:2 = delinearize_index %ii, %c42, %c64 : index
// Since we only tile by 1 along 2688 = 42x64, we only need a tensor<1x1xf32>
// type which is always "rectangular".
// linalg.matmul can now be fused here but it becomes a 1-D `linalg.dot` op.
%tmp_a = some_op %3[%multi_index#0, %multi_index#1][1, 1][1, 1]:
tensor<1x1xf32> from tensor<42x64xf32>
%a = tensor.collapse_shape %tmp_a [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
%b = tensor.extract_slice %5[%ii][1][1]: tensor<1xf32> from tensor<2688xf32>
%c = tensor.extract_slice %iter2[%ii][1][1]: tensor<1xf32> from tensor<2688xf32>
%res = linalg.elementwise "sub" ins(%a, %b: tensor<1xf32>, tensor<1xf32>) outs(%c: tensor<1xf32>)
// Perform the inverse linearization; the linearized index should fold to %ii.
%lin = linearize_index %multi_index#0, %multi_index#1, %c42, %c64 : index
%iter2_next = tensor.insert_slice %res into %iter2[%lin][1][1]:
tensor<1xf32> into tensor<2688xf32>
scf.yield %iter2_next: tensor<2688xf32>
}
scf.yield %x: tensor<2688xf32>
}
return %7: tensor<2688xf32>
In short, the linearize / delinearize abstraction proposed above is a particular form of structured affine expressions involving modulus and division operations. Mixing these with N-D subset operations increases expressiveness while still maintaining the benefits of a higher-dimensional structure. However, in the current form, a lowering (from matmul to dot) still needs to happen, which we aim to address with “Index Semantics Capturing Ops” (see below).
This approach is being contributed by @cbate to bring CUTLASS-style implementations of convolutions as first-class citizens in MLIR, please read his post for further details.
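For reference, here is a minimal sketch of that affine view, assuming the 42x64 basis of the running example and using today’s affine.apply (the dedicated linearize / delinearize ops are the proposed closed forms):
// Delinearize the linear index %ii: (d0) -> (d0 floordiv 64, d0 mod 64).
%row = affine.apply affine_map<(d0) -> (d0 floordiv 64)>(%ii)
%col = affine.apply affine_map<(d0) -> (d0 mod 64)>(%ii)
// Linearize back: (d0, d1) -> (d0 * 64 + d1); this folds back to %ii.
%lin = affine.apply affine_map<(d0, d1) -> (d0 * 64 + d1)>(%row, %col)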
Background: Sparse Case: Bridging The Semantic Gap Too Quickly
The sparse compiler work is a generalization that goes much beyond the limitations of rectangular data. The underlying principle of this work is that sparsity should be a property, not a tedious implementation detail: the computational semantics are fully shared between dense and sparse tensors. Sparsity is obtained by composing existing compute ops with a properly designed sparse type. This work taught us a few valuable lessons. As we had hoped, exploiting sparsity composes directly with other transformations and the actual sparse rewriting fits the progressive lowering philosophy for MLIR really well.
Nevertheless, the lack of the availability of certain IR abstractions forces the current implementation to bridge the semantic gap too quickly between:
- Sparsity as a property of the tensor type.
- Concrete in-memory storage schemes.
- Imperative loops for co-iteration.
Furthermore, the bufferization of sparse tensors breaks the typical 1:1 mapping between tensors and buffers.
Many sparse storage formats exist that only store the nonzeros (primary storage) together with some additional information to recover the enveloping index set (secondary storage), with the objective of exploiting sparsity by only storing and computing on nonzero elements. Due to the lack of certain abstractions, we encountered two complications while trying to utilise sparsity in a progressive manner.
Background: Sparse Case: The Opaque Pointer Workaround
First, sparse storage schemes do not follow the simple 1:1 mapping between tensors and rectangular buffers, since most of these formats consist of multiple buffers with indices, offsets, and values. In the current sparsification and bufferization passes, this complication has been overcome by maintaining a 1:1 mapping between a sparse tensor and its sparse storage, represented by an opaque pointer into an efficient C++ data structure provided by a support library.
Although a satisfactory solution in the short term, the introduction of an opaque pointer currently happens at too high a level. In the longer term, we would like to replace this mechanism with an actual mapping from sparse tensors to sparse data structure abstractions, followed by compiler-visible buffers, since this will make the lowering much more progressive and also make buffers in the resulting IR more amenable to subsequent optimizations. This would also remove the dependence on the support library.
For example, CSR storage of a sparse matrix really consists of a single array with offsets into two parallel arrays with indices and values, as illustrated below.
// CSR storage of tensor<32x64xf64, #CSR> with individual memrefs
%0 = offsets : memref<?xindex>
%1 = indices : memref<?xindex>
%2 = values : memref<?xf64>
scf.for %i = %c0 to %c32 step %c1 {
%lo = memref.load %0[%i] : memref<?xindex>
%hi = memref.load %0[%i + 1] : memref<?xindex>
scf.for %jj = %lo to %hi step %c1 {
%j = memref.load %1[%jj] : memref<?xindex>
%val = memref.load %2[%jj] : memref<?xf64>
…
}
}
In this representation, the IR representing the indexing structure and the IR representing the values are completely disconnected; instead, we would like to compound the different parts in a single data structure, which can be passed around and queried as a single entity.
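As a purely notional sketch of such a compound entity (the sparse_storage ops and type below are hypothetical placeholders, not existing MLIR), this could resemble:
// Pack the three CSR buffers into one value that can be passed around and
// queried as a single entity.
%csr = sparse_storage.pack %0, %1, %2
    : memref<?xindex>, memref<?xindex>, memref<?xf64> into !sparse_storage.csr<32x64, f64>
// Query the compound value for its constituent buffers.
%offsets = sparse_storage.offsets %csr : !sparse_storage.csr<32x64, f64> -> memref<?xindex>
%values = sparse_storage.values %csr : !sparse_storage.csr<32x64, f64> -> memref<?xf64>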
Background: Sparse Case: Lowering To Loops
Second, the lowering from high-level compute ops into loops currently makes a big semantic jump with a direct generation of scf.while loops, scf.for loops, and scf.if statements to implement sparse co-iteration.
In the example above, the innermost loop is already more complex than the outermost loop. For proper iteration over unions or intersections of index sets, the sparse compiler needs to resort to using a complex while loop (see the dot product example later in this document).
Abstractions for defining more elaborate data structures and for iterating over general index and data sets would refine progressive lowering, with all the advantages discussed later in this doc.
Note that this is also intertwined with the previous point on the “Opaque Pointer Workaround”: more progressive lowering to better iteration constructs also requires supporting operations on subsets of indices and data. With an opaque pointer abstraction, the implementation of all such supporting ops also needs to occur in the support library, which increases the scope of IR opacity on which no further transformations are possible.
Proposal: Going Beyond Existing Abstractions: Progressively Filling the Dense - Sparse Gap.
We can summarize the previous points as: there is no progressive path from dense to sparse in MLIR today. Going forward, we want to (a) be able to represent the whole spectrum dense → ragged → jagged → sparse → scf → LLVM, (b) allow transformations at all levels where they make sense and (c) only pay the performance price of what we use.
Working from first principles, we can propose abstractions that bridge the semantic gaps discussed in the first half of this RFC.
Proposal: GatherOp and Friends
At this time, MLIR is still missing a higher-order gather / scatter abstraction. From a transformation-oriented IR design perspective, the main idea is to have a representation to enumerate N-D indexings explicitly rather than as closed-form 0-D scalar indexing operations.
We propose to add gather / scatter operations whose design is driven by first-principle considerations and that compose with the rest of MLIR. For the purposes of this doc, scatter is assumed to be the symmetrical “insert” operation to gather. Considerations related to parallelism are out of the scope of this discussion.
The gather operation extracts a subset of the elements from an input tensor at the given indices. In its most general form, the tensor of indices specifies all the coordinates of every element to extract (i.e. the COO format).
// For each 1x2 triple of coordinates in %indices, extract the element (i.e. 0-D
// subset) at the coordinates triple in %input.
// This corresponds to the COO format, enumerating triples.
// In the IR below, the indexing part of the indices tensor is isolated by a
// whitespace (e.g. we write tensor<1x2x 3xindex> instead of tensor<1x2x3xindex> to
// better isolate 3xindex).
%out = gather %input[%indices] :
tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<1x2xf32>
A slice variant is provided that allows specifying whole slices of the input tensor, to support rectangular subsets and avoid obvious performance pits where possible.
// For each 5x6 singleton of coordinates in %indices, extract the 2-D slice
// %input[:, coordinate, :] (i.e. dimension 1 is indexed, dimensions 0 and 2 are taken whole).
%out = gather %input[%indices] coordinates = [1] :
tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x4xf32>
Only full slices are supported; if one desires partial slices, one should compose with other tensor ops such as tensor.extract_slice. This is to avoid a slippery slope of conflation and complexity that would make the op hard to use in practice (e.g. HLO gather and the related Gather HLO is Complex doc).
The gather / scatter ops aim at providing a generalization safety net (i.e. sparse computations can be represented with them, albeit inefficiently) while avoiding obvious performance gaps (i.e. if one requires taking N-D slices, rather than enumerating all 0-D points, then just take slices).
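For completeness, a notional sketch of the symmetric scatter, mirroring the first gather example above (the exact spelling is hypothetical):
// For each 1x2 triple of coordinates in %indices, insert the corresponding
// element (i.e. 0-D subset) of %updates at that coordinates triple in %dest.
%out = scatter %updates into %dest[%indices] :
    tensor<1x2xf32> into tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<4x4x4xf32>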
Fusing through Gather
As before, “tiling by size 1” allows fusing through gather operations while maintaining structure on the dimensions that are not sliced by fusion.
func.func @matmul_gather_bias(%0: tensor<42x32xf32>, %1: tensor<32x64xf32>, %2: tensor<42x64xf32>,
%indices: tensor<5x1xindex>, %5: tensor<5x64xf32>, %6: tensor<5x64xf32>)
-> tensor<5x64xf32> {
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%4 = gather %3[%indices] coordinates=[0]: tensor<42x64xf32>[tensor<5x 1xindex>] -> tensor<5x64xf32>
%7 = linalg.elementwise "sub" ins(%4, %5: tensor<5x64xf32>, tensor<5x64xf32>) outs(%6: tensor<5x64xf32>)
return %7: tensor<5x64xf32>
}
A notional fusion through gather is expected to resemble:
func.func @matmul_gather_bias(%0: tensor<42x32xf32>, %1: tensor<32x64xf32>, %2: tensor<42x64xf32>,
%indices: tensor<5x1xindex>, %5: tensor<5x64xf32>, %6: tensor<5x64xf32>)
-> tensor<5x64xf32> {
%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %c5 step %c1 iter_args(%iter = %6: tensor<5x64xf32>) -> tensor<5x64xf32> {
%index_tensor = tensor.extract_slice %indices[%i][1][1]: tensor<1xindex> from tensor<5x 1xindex>
%index = tensor.extract %index_tensor[%c0]: tensor<1xindex>
// linalg.matmul can now be fused here but it becomes a linalg.matvec.
%a = tensor.extract_slice %3[%index, 0][1, 64][1, 1]: tensor<1x64xf32> from tensor<42x64xf32>
%b = tensor.extract_slice %5[%i, 0][1, 64][1, 1]: tensor<1x64xf32> from tensor<5x64xf32>
%c = tensor.extract_slice %iter[%i, 0][1, 64][1, 1]: tensor<1x64xf32> from tensor<5x64xf32>
%d = linalg.elementwise "sub" ins(%a, %b: tensor<1x64xf32>, tensor<1x64xf32>) outs(%c: tensor<1x64xf32>)
%e = tensor.insert_slice %d into %iter[%i, 0][1, 64][1, 1]: tensor<1x64xf32> into tensor<5x64xf32>
scf.yield %e: tensor<5x64xf32>
}
return %7: tensor<5x64xf32>
}
Similarly to the “fusion through reshape”, fusing through gather induces a lowering in the rank of the structured operation (i.e. from matmul to matvec). Going beyond this “premature lowering” is discussed further down.
Proposal: ConcatOp
A concat op is a limited form of gather. It is an important enough use case to warrant a specific operation and avoid explicit enumeration of indices, even along a single dimension.
%out = concat %t1 and %t2 along dimension = 1 :
tensor<3x4x?xf32> and tensor<3x5x?xf32> into tensor<3x9x?xf32>
A notional fusion through concat is expected to be spelled similarly to the gather case, with a difference in the way single-element extraction is implemented: an arith.select operation is used to switch between %t1 and %t2.
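A minimal sketch of that single-element extraction, assuming tile size 1 along the concatenated dimension (%i, %j, %k index the 3x9x? result and %c4 is the size of %t1 along dimension 1):
// Select the source depending on whether %j falls within %t1 along dimension 1.
%in_t1 = arith.cmpi ult, %j, %c4 : index
%j2 = arith.subi %j, %c4 : index
// A real lowering would guard the two extracts (e.g. with scf.if), since only one
// of them is in bounds for a given %j.
%e1 = tensor.extract %t1[%i, %j, %k] : tensor<3x4x?xf32>
%e2 = tensor.extract %t2[%i, %j2, %k] : tensor<3x5x?xf32>
%e = arith.select %in_t1, %e1, %e2 : f32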
Similarly to “fusion through reshape”, fusing through concat induces a lowering in the rank of the structured operation (i.e. from matmul to matvec). Going beyond this “premature lowering” is discussed further down.
A buffer version of concat could resemble:
%out = concat %t1 and %t2 along dimension = 1 :
memref<3x4x?xf32> and memref<3x5x?xf32> into memref<3x9x?xf32>
The actual materialization of this operation into lower-level code will be subject to tradeoffs related to in-place bufferization and type expressiveness. It is worth noting that one possible lowering is through a 1-D slice version of gather. This would provide a bootstrapping path until a better option is available.
Proposal: Bufferization, Copy, View and The Need for More Expressive Types
As currently specified, and with the existing MLIR type support, the collapse_shape, expand_shape, gather, scatter and concat operations may have to lower to memory copies. This is because the buffer type system is not rich enough to allow multiple 1-D views in the same 2-D type (unless it can be statically determined that they are all separated by the same constant stride and can be represented as a single 2-D strided memref).
This is visible more clearly in a notional buffer version of the op:
// memref<?x4xf32> is a contiguous buffer of ?x4 elements, gather from random input
// slices must copy to the contiguous output.
%out = gather %input[%indices] coordinates = [1] :
memref<4x4xf32>[memref<?x1xindex>] -> memref<?x4xf32>
// Nested/jagged buffer support would allow gather to return a “view” into the data.
%out = gather %input[%indices] coordinates = [1] :
memref<4x4xf32>[memref<?x1xindex>] -> memref<? x memref<4xf32>>
In other words, the same “strided memref is n-D rectangular” argument discussed in the context of “fusion through reshape” also prevents in-place bufferization.
Note that bufferization of concat presents some special opportunities by inverting the flow of SSA values. In pseudo-IR, this resembles:
%out_buffer = some_buffer_op … : memref<3x9x?xf32>
…
%t1 = memref.subview %out_buffer[0, 0, 0][3, 4, ?][1, 1, 1]: memref<3x9x?xf32> to memref<3x4x?xf32>
%t2 = memref.subview %out_buffer[0, 4, 0][3, 5, ?][1, 1, 1]: memref<3x9x?xf32> to memref<3x5x?xf32>
// The result of the concat is RAUW’ed with %out_buffer.
In the general case, and depending on how they come into existence in the IR, memref<3x4x?xf32> and memref<3x5x?xf32> are not guaranteed to have the same base pointer and cannot be concatenated into a single memref<3x9x?xf32> without memory copies.
Similarly, sparse tensor concatenation and reshaping currently require copies. In order to concatenate (same dimension), expand (to higher dimensions) or collapse (to lower dimensions) a sparse tensor, we iterate over all stored elements, compute the new indices, and insert these into freshly allocated sparse tensors. An example of collapsing a sparse matrix (2-dim) into a sparse vector (1-dim) is illustrated below in pseudo code. Clearly, here we would like to be able to more efficiently change the “view” into the sparse tensor, without having to move data around to make that happen.
iter = src->toCOO() : tensor<10x10xf32>
coo = newSparseCOO()
while (elem = iter->getNext()) {
coo->add(reshape(elem.indices), elem.value) // maps i,j to ii
}
s = newSparseTensor(coo) : tensor<100xf32>
These issues can be addressed with the addition of a first-class jagged buffer type.
Proposal: First-Class Jagged Buffer
A nested, jagged, buffer type is an N-D generalization of a 1-D jagged array. It is strictly more expressive than a strided memref and allows in-place bufferization of gather.
As a simple illustration, note how a memref<4 x memref<?xf32>> is strictly more expressive than the strided memref<4x?xf32>: the former allows 4 different 1-D memrefs, each with its own ? size and base pointer.
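As a notional sketch (memref-of-memref is not a valid type today; the pseudo-IR below only illustrates the intent):
// Four rows, each with its own dynamic size and base pointer.
%rows = ... : memref<4 x memref<?xf32>>
// Loading row 2 yields an ordinary 1-D memref whose size is independent of the
// other rows.
%row2 = memref.load %rows[%c2] : memref<4 x memref<?xf32>>
%n2 = memref.dim %row2, %c0 : memref<?xf32>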
The in-place aspect is dependent on concrete lowerings and typically involves some runtime computation of the start address (i.e. inspector-executor style) or iterating through pointers and performing pointer copies (which is still deemed “in-place” for our purposes). Considerations re. lowering to LLVM are kept for future discussion as different alternatives come with tradeoffs that are too large to unpack here.
A similar argument holds for concat: it must copy when the base pointers of the underlying buffers are not known to be two contiguous subviews.
%out = concat %t1 and %t2 along dimension = 1 :
memref<3x4x?xf32> and memref<3x5x?xf32> into memref<3x9x?xf32>
This could be addressed by allowing a memref<3x9xmemref<?xf32>> to mix base pointers coming from different alloc ops. This is an area subject to multiple tradeoffs that will require more work and discussion.
With the availability of a nested buffer type, operations that previously required alloc + copy to cope with either the rectangularity of strided memref or more general sparse tensor requirements can bufferize much more efficiently.
With such an extension, ragged and jagged tensors / buffers will be expressible as first-class citizens in the IR. In the limit, operations and types should fold and canonicalize gracefully to dense at one end of the spectrum (e.g. memref<?x?xf32> and static pointer arithmetic), to sparse at the other end of the spectrum (e.g. memref<?xmemref<?xmemref<f32>>> would define a proper jagged, viz. ITPACK, format for the values), and to intermediate forms that make sense.
In other words, a good high-level representation and folding patterns allow us to only pay the price of the abstractions we use.
Basis of Confidence: Relation To Indirect Convolution and GEMM Implementations
Lifting operations to work on ragged / jagged tensors and buffers yields algorithms that generalize recent work on Indirect GEMM / Indirect Convolution (tech report). This work shows these algorithms are very competitive with traditional dense convolution and GEMM implementations, forming a reasonable basis of confidence that, even in the absence of folding all the way down to inlined address computations, the expected performance is very strong.
Proposal: Insert and Extract Subset Op Interfaces with Explicit Copy and View Behavior
The following operations have a flavor of a common notional SubsetOpInterface: tensor.extract_slice, tensor.insert_slice, tensor.collapse_shape, tensor.expand_shape, gather, scatter, memref.collapse_shape, memref.expand_shape, memref.subview and, to some extent, tensor.pad and concat.
It is still TBD whether a single SubsetOpInterface or a split into a SubsetExtractOpInterface and a SubsetInsertOpInterface is more appropriate. Details will matter as we uncover how TilingInterface and other interfaces need to evolve to support more general types.
These abstractions should also be extended to support both copy and view semantics (once the types are powerful enough to support it), along with improved verification for both cases. Such an extension can be achieved by simply adding a first-class copy / view attribute on SubsetInsertOpInterface and the ops implementing it.
A copy / view attribute extension is needed to allow expressing the bufferization of some of the proposed ops with the current type system (i.e. starting by explicitly copying is always fine), but it also provides opportunities to copy data into an allocation in fast(er) memory through bufferization.alloc_tensor.
This is already available when e.g. packing with tensor.pad (see section 3.2 of the tech report, for increasing spatial locality and reducing compulsory misses and cache line splitting), and should be generalized.
Such an attribute should be standardized on the ops that insert into / extract from tensors, subject to further reuse and conflict analysis during bufferization (which may still decide to override the view bit into a copy bit to avoid memory conflicts).
This could resemble:
// Current behavior spelled out more clearly with a `view` attribute.
%r = memref.expand_shape view(%0) [[0, 1], [2]] : memref<?x?xf32> into memref<?x5x?xf32>
// New behavior not yet supported: `copy` and reshape to a faster memory.
%r = memref.expand_shape copy(%0) [[0, 1], [2]] :
memref<?x?xf32> into memref<?x5x?xf32, #strided_spec, #address_space>
// memref<?x4xf32> is a contiguous buffer of ?x4 elements, gather from random input
// slices must copy to the contiguous output.
%out = gather copy(%input[%indices]) coordinates = [1] :
memref<4x4xf32>[memref<?x1xindex>] -> memref<?x4xf32>
// Nested/jagged buffer support would allow gather to return a “view” into the data.
%out = gather view(%input[%indices]) coordinates = [1] :
memref<4x4xf32>[memref<?x1xindex>] -> memref<? x memref<4xf32>>
// This IR has strong implications: the indices *must be exactly the range*
// [%indices[0], %indices[dim(%indices) - 1]].
// We may choose to fail verification or use this as extra information: TBD.
%out = gather view(%input[%indices]) coordinates = [1] :
memref<4x4xf32>[memref<?x1xindex>] -> memref<?x4xf32>
Proposal: Control-Flow To Iterate Over More Powerful Tensor and Buffer Types.
As was validated by the sparse work, we fully expect operations that abide by structured IR principles to compose naturally with these new types and operations.
One key abstraction to support proper composition of transformations on structured ops and types is that of control-flow operations that iterate over an operand of a higher-order data type.
For now, MLIR has scf.while and scf.for. We will want new ops that may be better connected to the data types. In particular, we anticipate we will need constructs over index sets, motivated by requirements already found in sparse compilation but not yet formalized more generally.
Today, the sparse compiler rewrites ops to loops (a.k.a tiling by 1 in all dimensions) by using a combination of scf.while / scf.for and scf.if to express co-iteration. This rewriting bridges a semantic gap that is too large, but there is currently no other way to express the co-iterating constructs that are needed to realize sparse iteration spaces. The introduction of a higher-level set-based loop abstraction, which iterates over sets of indices (including the union, intersection, difference, and complement of sets), would preserve semantics (such as the parallel or reduction nature of the loop) and be a better fit to the overall progressive lowering philosophy of MLIR.
In addition, subsequent passes could provide alternative implementations of co-iteration, such as mapping the loops onto the while / for / if implementation currently used, or alternative ways to generate the indices in a set (Knuth describes the two-way merge algorithm in Section 5.2.4 of Volume 3 of The Art of Computer Programming; other ways of implementing the same exist as well, such as densely expanding one set and sparsely iterating over the other).
The following example, expressing a dot product between two sparse 1-D tensors, illustrates the shortcoming of the current direct lowering to imperative loops of the scf dialect.
// Linalg generic implementation of x = SUM_i a(i) * b(i) for sparse vectors a and b.
func.func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
%argb: tensor<?xf64, #SparseVector>,
%argx: tensor<f64>) -> tensor<f64> {
%0 = linalg.generic #trait_dot
ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
outs(%argx: tensor<f64>) {
^bb(%a: f64, %b: f64, %x: f64):
%1 = arith.mulf %a, %b : f64
%2 = arith.addf %x, %1 : f64
linalg.yield %2 : f64
} -> tensor<f64>
return %0 : tensor<f64>
}
During sparsification, the compiler needs to generate a loop that co-iterates over the intersection of the index sets of sparse tensors a and b. Lacking set iteration constructs, however, the compiler must immediately lower this to the following while loop and nested if constructs, expressed in pseudo-code below, which considerably obscures the compute semantics to subsequent transformations.
while (ia < pa1_end && ib < pb1_end) {
ia0 = a1_indices[ia]; // fetch a index
ib0 = b1_indices[ib]; // fetch b index
i = min(ia0,ib0);
if (ia0 == i && ib0 == i) { // only executes when i in both a and b
x_val += a_values[ia] * b_values[ib];
}
ia += (ia0 == i) ? 1 : 0;
ib += (ib0 == i) ? 1 : 0;
}
A set-based loop iterator would preserve the actual semantics of what is expressed above much better. In addition, it would greatly simplify the implementation of the sparsification pass itself, by implementing the required interfaces to extract (partial) slices of data along particular dimensions. The generation of all the conditionals needed for the co-iteration can then be delegated to a later lowering pass, after all relevant canonicalizations and foldings have been applied on the higher-level types.
for ((i, ia, ib) in intersection(a, b)) { // yields i in intersection and ia, ib
x_val += a_values[ia] * b_values[ib];
}
Proposal: Index Semantics Capturing Ops
To properly support such higher-order control-flow operations and bridge the gap to considerations such as efficient lowering of “fusion through reshape and gather”, we anticipate higher-order operations that express indexing semantics on whole tensors and structured sets will be useful.
We believe this will be useful to avoid committing to loops over scalar load / store too early and to capture high-level IR that is transported to the proper places to enable efficient traversals to be derived (both control-flow and type access).
This design principle has served us well in the vector dialect.
In a world where we only have 0-D affine.apply, affine.linearize, affine.delinearize and n-D full gather / scatter enumeration, we will likely need ops with a region to carry the relevant information until it can be folded into static address calculations or more intrusive runtime enumerations.
// Some examples that illustrates index ops on sets with a region.
// Values from min to max incremented by step, generalizable to n-D, rectangular.
%0 = indexing.numpy_like_linspace %min, %max, %step : tensor<?xindex>
// Half-open interval returns a tensor<1xindex> as one would expect.
%0 = indexing.numpy_like_linspace 42, 43, 1000 : tensor<1xindex>
// We statically know (%max - %min) / %step is 4 but not much else.
%0 = indexing.numpy_like_linspace %min, %max, %step : tensor<4xindex>
// General op with a region to represent linearize / delinearize on the whole set
// and avoid lowering to loops prematurely.
// The result is not necessarily a rectangular tensor and may require a ragged / jagged memref.
%1 = indexing.map (%a) in range(%0 : tensor<?xindex>) -> tensor<?x?xindex> {
  %r:2 = affine.delinearize_index %a into (%b0, %b1) : index, index
  return %r#0, %r#1 : index, index
}
// Space filling curves (used e.g. in the Ruy library) are known to induce strong locality, are
// cheap to compute and hard to recover from lower-level bit-twiddling IR.
%2 = indexing.some_hilbert_curve_based_reordering_of_indices %1 : tensor<?x?xindex>
// %1 and %2 can further be fed to a gather op that takes multiple 1-D slices
// to implement an indirect convolution algorithm via e.g. a memref<?xmemref<?xf32>>.
// Depending on the amount of static information, this could all be unrolled and
// lowered away to avoid materializing indices in memory.
The concept of IntegerSet is also expected to be useful in this context, as it models a well-known class of higher-order sets and will properly lower to affine abstractions in MLIR by design.
union, intersection, difference, and complement operations on sets of indices are also expected to be useful to partition and tile data, e.g. for parallelization and load balancing. These should be first-class citizen operations, to avoid some well-known exponential complexity blowups and the resulting impenetrable spaghetti code that are all too common with polyhedral techniques.
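As a notional sketch (op and type names are placeholders):
// First-class set operations on index sets.
%inter = indexing.intersection %ia, %ib : !indexing.index_set
%union = indexing.union %ia, %ib : !indexing.index_set
// A set-based loop (as in the sparse dot product example above) can then iterate
// over %inter directly; partitioning %union into balanced subsets can drive
// parallelization and load balancing.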
General implementations may rely on type erasure and fast underlying C++ library implementations of sets and hash maps. Such precedent has been useful in the past to bring up structured and sparse codegen without having to frontload the implementation of complex low-level types in LLVM.
Future Work (Outside of The Current Proposal)
- Looping constructs and abstractions for parallel and distributed extensions.
- Generalized data abstractions in MLIR (lists, sets, maps).
- Highly optimized LLVM representation of types.