[RFC] Structured Codegen Beyond Rectangular Arrays

Authors: @nicolasvasilache (Google), @aartbik (Google)

Hello everyone,

Aart and I would like to share the results of an internal brainstorming on ideas that have been lingering for some time based on our learnings from building parts of MLIR Codegen.

There are 2 parts to this RFC:

  • a thorough background analysis of existing abstractions and their current limitations
  • a set of abstractions we propose to overcome these limitations.

Not all abstractions proposed here should be undertaken at the same time; they should be further sliced into separate RFCs. Still, it is important to see the overall common vision in one comprehensive post.

If people think a public Google doc is a more inclusive or efficient medium for structuring comments, please let us know.

We hope it will be an exciting read!

Context

Structured and Retargetable Code Generation (tech report) is built around dense rectangular strided subsets (a.k.a. strided memrefs). A summary discussion about its principles is available here.

The Sparse Compiler (ACM pub, tech report) shares the computation representation layer as a target of higher level dialects and provides proper composition through a custom sparse tensor type.

This RFC proposes a vision for a progressive path forward from dense to sparse to better cover the whole spectrum of dense → ragged → jagged → sparse types and only pay the performance price of what we use. We discuss the limitations of the current abstractions and propose a path towards the unification and better composition of these approaches (both expressiveness and transformations).

This RFC also proposes improved abstractions related to iteration and data storage that refine progressive lowering by providing a much better bridge over the semantic gap between dense / ragged / jagged / sparse tensors and their implementation.

It also comes in the context of renewed interest in making concrete progress on a tensor compute primitives dialect (RFC, ODM). We anticipate this discussion will also be relevant in that context and help trigger the next round of unification and collaboration.

Background: N-D Compute Ops Over M-D Data: Current Limitations.

MLIR supports multiple levels of IR operating at different levels of granularity. Structured ops (e.g. matmul, contraction, elementwise, transpose, generic etc) are compute primitives at a level of abstraction similar to “library calls on structured data”.

High-level operation semantics encode the relationship between:

  1. the data types,
  2. the iteration domain / control-flow / indexing required to traverse the data operands and
  3. the computational payload of the op.

This relationship between compute and data provides a flexible level of IR on which bufferization, tiling, fusion, vectorization, lowering to loops and other structured transformations can occur.

Such ops support both tensor and buffer forms, resembling:

%3 = linalg.matmul ins(%1, %2 : tensor<42x32xf32>, tensor<32x64xf32>)
   outs(%0 : tensor<42x64xf32>)

linalg.matmul ins(%1, %2 : memref<42x32xf32>, memref<32x64xf32>)
   outs(%0 : memref<42x64xf32>)

While this form allows a significant class of transforms and dialects to compose properly, it is also limited, by design, to the semantics that the operand types support. This comprises:

  • Strided tensors, buffers and the subset operations tensor.extract_slice, tensor.insert_slice and memref.subview (in the dense case).
  • The sparse_tensor dialect, types and lowerings (in the sparse case).

In this section, we discuss current limitations starting from the dense case and gradually relax the abstractions towards the sparse case.

Background: Dense Case: collapse_shape and expand_shape.

Limitations of the structured dense approach appear when composing with operations such as memref.collapse_shape.

// Dimension collapse: (i, j) -> ij, k -> k
%1 = memref.collapse_shape %0 [[0, 1], [2]] :
  memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>

The current documentation states:

The memref.collapse_shape op produces a new view with a smaller rank whose 
sizes are a reassociation of the original view. The operation is limited to such 
reassociations, where subsequent, contiguous dimensions are collapsed into a 
single dimension. Such reassociations never require additional allocs or copies. 
Collapsing non-contiguous dimensions is undefined behavior.

The key abstraction gap is that the result of taking a subset of a memref.collapse_shape can generally not be represented by a strided memref. As a consequence, collapse_shape and expand_shape often act as a boundary for transformations.

A legalization pass exists that percolates collapse_shape up and expand_shape down, thus increasing the scope of transformations to higher-dimensional ops. This is a tradeoff, and we are now looking at going beyond this limitation more generally.

We want to allow transformations to permeate such boundaries and better explore the expressiveness / transformation tradeoff.

Background: Dense Case: Transformations of Interest

Two key transformations that one wants to apply at this level of abstraction are tiling and fusion. Note that tiling, fusion and lowering to loops are intimately related:

  • lowering to loops is obtained by “tiling by 1” along all dimensions (see the sketch after this list),
  • fusion is performed on tiles of data (i.e. tile and fuse). In the limit, classical loop fusion can be achieved by “tiling by 1” along all dimensions and fusing.
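
To make the first point concrete, here is a minimal sketch (buffer names %A, %B, %C and the constants are hypothetical) of a matmul “tiled by 1” along all dimensions, which degenerates into plain scalar loops:

scf.for %i = %c0 to %c42 step %c1 {
  scf.for %j = %c0 to %c64 step %c1 {
    scf.for %k = %c0 to %c32 step %c1 {
      // The scalar payload of linalg.matmul applied to 1x1 tiles.
      %a = memref.load %A[%i, %k] : memref<42x32xf32>
      %b = memref.load %B[%k, %j] : memref<32x64xf32>
      %c = memref.load %C[%i, %j] : memref<42x64xf32>
      %m = arith.mulf %a, %b : f32
      %s = arith.addf %c, %m : f32
      memref.store %s, %C[%i, %j] : memref<42x64xf32>
    }
  }
}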

One use case that showed up in the past and that we aim at generalizing is fusion through reshape. Let’s take the example described in fusion through reshape as a motivating example and transcribe it to MLIR:

func.func @matmul_bias(%0: tensor<42x32xf32>, %1: tensor<32x64xf32>, %2: tensor<42x64xf32>,
                                         %5: tensor<2688xf32>, %6: tensor<2688xf32>)
  -> tensor<2688xf32> {

  %3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
    outs(%2 : tensor<42x64xf32>)
  %4 = tensor.collapse_shape %3 [[0, 1]] : 
    tensor<42x64xf32> into tensor<2688xf32>
  %7 = linalg.elementwise "sub" ins(%4, %5: tensor<2688xf32>, tensor<2688xf32>) 
    outs(%6: tensor<2688xf32>)

  return %7: tensor<2688xf32>
}

If we tile the elementwise operation by some value %ts, we obtain IR resembling:

%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
  outs(%2 : tensor<42x64xf32>)
%4 = tensor.collapse_shape %3 [[0, 1]] : 
  tensor<42x64xf32> into tensor<2688xf32>
%7 = scf.for %i = %c0 to %N step %ts iter_args(%iter = %6: tensor<2688xf32>) 
    -> tensor<2688xf32> {
  %a = tensor.extract_slice %4[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
  %b = tensor.extract_slice %5[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
  %c = tensor.extract_slice %iter[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
  %res = linalg.elementwise "sub" ins(%a, %b: tensor<?xf32>, tensor<?xf32>) outs(%c: tensor<?xf32>)
  %iter_next = tensor.insert_slice %res into %iter[%i][%ts][1]: tensor<?xf32> into tensor<2688xf32>
  scf.yield %iter_next: tensor<2688xf32>
}
return %7: tensor<2688xf32>

It is now unclear how to even write the fusion of %3 and %4 into the loop: there are no operations and more importantly, there are no types to express taking a 1-D %ts-sized slice out of a 2-D tensor or memref.

%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
  outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %N step %ts iter_args(%iter = %6: tensor<2688xf32>) 
    -> tensor<2688xf32> {
  // TODO: How to even express a 1-D rectangular subset of size %ts out of a 2-D rectangular tensor?
  %a = tensor.extract_slice %3[...]: tensor<?xf32> from tensor<42x64xf32>
  %b = tensor.extract_slice %5[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
  %c = tensor.extract_slice %iter[%i][%ts][1]: tensor<?xf32> from tensor<2688xf32>
  %res = linalg.elementwise "sub" ins(%a, %b: tensor<?xf32>, tensor<?xf32>) outs(%c: tensor<?xf32>)
  %iter_next = tensor.insert_slice %res into %iter[%i][%ts][1]: tensor<?xf32> into tensor<2688xf32>
  scf.yield %iter_next: tensor<2688xf32>
}
return %7: tensor<2688xf32>

In particular, suppose we take a prime size (e.g. %ts = 5). Depending on the value of %i, the slice could cross row boundaries and take a different number of elements from each row:

  • When %i == 0, the extract_slice yields a slice with 5 elements from row 0 and 0 from row 1.
  • When %i == 62, the extract_slice yields a slice with 2 elements from row 0 and 3 from row 1.

Additionally, after bufferization, %3 could be materialized as a non-contiguous strided buffer and the 5 elements would not even be in contiguous memory locations in the first place. This is not representable as a single 1-D rectangular region.

There is currently no good support for representing such an op and a data type that “gathers” elements.

Background: Dense Case: Expressiveness vs Transformation Power: Tile By Size 1 to Project Out a Non-Contiguous Dimension.

One possible approach to solving the expressiveness gap is to realize that it disappears if we “project away” the non-contiguous dimensions of the data type. This is possible by restricting the transformation power (i.e. “tiling by size 1” on the relevant collapsed multi-dimension, a.k.a lowering to loops along the collapsed multi-dimension).

%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
  outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %N step %c1 iter_args(%iter = %6: tensor<2688xf32>) 
    -> tensor<2688xf32> {
  // TODO: Since we only tile by 1 along 2688 = 42x64, we only need a tensor<1x1xf32> type 
  // which is always "rectangular".
  
  // linalg.matmul can now be fused here but it becomes a 1-D linalg.dot op.
  %tmp_a = some_op %3[f(%i), g(%i)][1, 1][1, 1]: tensor<1x1xf32> from tensor<42x64xf32>
  %a = tensor.collapse_shape %tmp_a [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
  %b = tensor.extract_slice %5[%i][1][1]: tensor<1xf32> from tensor<2688xf32>
  %c = tensor.extract_slice %iter[%i][1][1]: tensor<1xf32> from tensor<2688xf32>
  %res = linalg.elementwise "sub" ins(%a, %b: tensor<1xf32>, tensor<1xf32>) outs(%c: tensor<1xf32>)
  %iter_next = tensor.insert_slice %res into %iter[%i][1][1]: tensor<1xf32> into tensor<2688xf32>
  scf.yield %iter_next: tensor<2688xf32>
}
return %7: tensor<2688xf32>

In other words, when navigating the expressiveness / transformation tradeoff, it is possible to restrict expressiveness to increase transformation power and vice-versa.

Additionally, we can introduce linearize / delinearize abstractions to help represent this in closed form, while:

  • still tiling with arbitrary tile sizes
  • using the tile size 1 for bridging the expressiveness gap

In a nutshell, this approach resembles:

%3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
  outs(%2 : tensor<42x64xf32>)
%7 = scf.for %i = %c0 to %N step %ts iter_args(%iter = %6: tensor<2688xf32>) 
    -> tensor<2688xf32> {
  %x = scf.for %ii = %i to min(%N, %i + %ts) step %c1 iter_args(%iter2 = %iter: tensor<2688xf32>)
      -> tensor<2688xf32> {
    // Delinearize the linear index %ii in the 42x64 basis 
    // (note that this abstraction also supports dynamic SSA basis values).
    %multi_index:2 = delinearize_index %ii, %c42, %c64 : index

    // Since we only tile by 1 along 2688 = 42x64, we only need a tensor<1x1xf32> 
    // type which is always "rectangular".
    // linalg.matmul can now be fused here but it becomes a 1-D `linalg.dot` op.
    %tmp_a = some_op %3[%multi_index#0, %multi_index#1][1, 1][1, 1]: 
      tensor<1x1xf32> from tensor<42x64xf32>
    %a = tensor.collapse_shape %tmp_a [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
    %b = tensor.extract_slice %5[%ii][1][1]: tensor<1xf32> from tensor<2688xf32>
    %c = tensor.extract_slice %iter2[%ii][1][1]: tensor<1xf32> from tensor<2688xf32>
    %res = linalg.elementwise "sub" ins(%a, %b: tensor<1xf32>, tensor<1xf32>) outs(%c: tensor<1xf32>)
    
    // Perform the inverse linearization; %lin should fold to %ii.
    %lin = linearize_index %multi_index#0, %multi_index#1, %c42, %c64: index
    %iter2_next = tensor.insert_slice %res into %iter2[%lin][1][1]: 
      tensor<1xf32> into tensor<2688xf32>
    scf.yield %iter2_next: tensor<2688xf32>
  }
  scf.yield %x: tensor<2688xf32>
}
return %7: tensor<2688xf32>

In short, the linearize / delinearize abstraction proposed above is a particular form of structured affine expressions involving modulus and division operations. Mixing these with N-D subset operations allows increasing expressiveness while still maintaining the benefits of a higher-dimensional structure. However, in the current form, a lowering (from matmul to dot) still needs to happen, which we aim to address with “Index Semantics Capturing Ops” (see below).

This approach is being contributed by @cbate to bring CUTLASS-style implementations of convolutions as first-class citizens in MLIR, please read his post for further details.

Background: Sparse Case: Bridge The Semantic Gap Too Quickly

The sparse compiler work is a generalization that goes much beyond the limitations of rectangular data. The underlying principle of this work is that sparsity should be a property, not a tedious implementation detail: the computational semantics are fully shared between dense and sparse tensors. Sparsity is obtained by composing existing compute ops with a properly designed sparse type. This work taught us a few valuable lessons. As we had hoped, exploiting sparsity composes directly with other transformations and the actual sparse rewriting fits the progressive lowering philosophy for MLIR really well.
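
As a minimal sketch (the exact encoding syntax has evolved over time, so take the attribute below as illustrative), the same structured op runs on dense or sparse operands purely by changing the tensor type:

#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>

// Same computational semantics as the dense version; sparsity is a type property.
%0 = linalg.matmul ins(%a, %b : tensor<42x32xf32, #CSR>, tensor<32x64xf32>)
                  outs(%c : tensor<42x64xf32>)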

Nevertheless, the lack of the availability of certain IR abstractions forces the current implementation to bridge the semantic gap too quickly between:

  • Sparsity as a property of the tensor type.
  • Concrete in-memory storage schemes.
  • Imperative loops for co-iteration.

Furthermore, the bufferization of sparse tensors breaks the typical 1:1 mapping between tensors and buffers.

Many sparse storage formats exist that only store the nonzeros (primary storage) together with some additional information to recover the enveloping index set (secondary storage), with the objective of exploiting sparsity by only storing and computing on nonzero elements. Due to the lack of certain abstractions, we encountered two complications while trying to utilize sparsity in a progressive manner.

Sparse Case: The Opaque Pointer Workaround

First, sparse storage schemes do not follow the simple 1:1 mapping between tensors and rectangular buffers, since most of these formats consist of multiple buffers with indices, offsets, and values. In the current sparsification and bufferization passes, this complication has been overcome by maintaining the 1:1 mapping between a sparse tensor and its sparse storage, represented by an opaque pointer into an efficient C++ data structure provided in a support library.
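
Roughly, and with illustrative (not exact) runtime entry point names, the current lowering resembles:

// The sparse tensor is lowered to a single opaque pointer into a C++ data
// structure; every storage query becomes a call into the support library.
%t = call @newSparseTensor(...) : (...) -> !llvm.ptr<i8>
%pointers = call @sparsePointers(%t, %c1) : (!llvm.ptr<i8>, index) -> memref<?xindex>
%indices  = call @sparseIndices(%t, %c1)  : (!llvm.ptr<i8>, index) -> memref<?xindex>
%values   = call @sparseValuesF64(%t)     : (!llvm.ptr<i8>) -> memref<?xf64>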

Although a satisfactory solution in the short term, the introduction of an opaque pointer currently is done at too high a level. In the longer term we would like to replace this mechanism with an actual mapping from sparse tensors to sparse data structure abstractions, followed by compiler visible buffers, since this will make the lowering much more progressive, and also make buffers in the resulting IR more amenable to subsequent optimizations. This would also remove the dependence on the support library.

For example, CSR storage of a sparse matrix really consists of a single array with offsets into two parallel arrays with indices and values, as illustrated below.

// CSR storage of tensor<32x64xf64, #CSR> with individual memrefs
%0 = offsets : memref<?xindex>
%1 = indices : memref<?xindex>
%2 = values  : memref<?xf64>
scf.for %i = %c0 to %c32 step %c1 {
    %lo = memref.load %0[%i] : memref<?xindex>
    %hi = memref.load %0[%i + 1] : memref<?xindex>
    scf.for %jj = %lo to %hi step %c1 {
        %j   = memref.load %1[%jj] : memref<?xindex>
        %val = memref.load %2[%jj] : memref<?xf64>
        …
    }
}

In this representation, the buffers holding the indexing information and the buffers holding the values are decoupled; we would like to compound the different parts into a single data structure which can be passed around and queried as a single entity.

Background: Sparse Case: Lowering To Loops

Second, the lowering from high-level compute ops into loops currently makes a big semantic jump with a direct generation of scf.while loops, scf.for loops, and scf.if statements to implement sparse co-iteration.

In the example above, the innermost loop is already more complex than the outermost loop. For proper iteration over unions or intersections of index sets, the sparse compiler needs to resort to using a complex while loop (see the dot product example later in this document).

Abstractions for defining more elaborate data structures and for iterating over general index and data sets would refine progressive lowering, with all the advantages discussed later in this doc.

Note that this is also intertwined with the previous point on the “Opaque Pointer Workaround” because more progressive lowering to better iterating constructs also requires supporting operations on subsets of indices and data. With an opaque pointer abstraction, the implementation of all such supporting ops also needs to occur in the support library, which increases the scope of IR opacity on which no further transformations are possible.

Proposal: Going Beyond Existing Abstractions: Progressively Filling the Dense - Sparse Gap.

We can summarize the previous points as: there is no progressive path from dense to sparse in MLIR today. Going forward, we want to (a) be able to represent the whole spectrum of dense → ragged → jagged → sparse → scf → LLVM, (b) allow transformations at all levels that make sense and (c) only pay the performance price of what we use.

Working from first principles, we can propose abstractions that bridge the semantic gaps discussed in the first half of this RFC.

Proposal: GatherOp and Friends

At this time, MLIR is still missing a higher-order gather / scatter abstraction. From a transformation-oriented IR design perspective, the main idea is to have a representation to enumerate N-D indexings explicitly rather than as closed-form 0-D scalar indexing operations.

We propose to add gather / scatter operations whose design is driven by first-principle considerations and that compose with the rest of MLIR. For the purposes of this doc, scatter is assumed to be the symmetrical “insert” operation to gather. Considerations related to parallelism are out of the scope of this discussion.

The gather operation extracts a subset of the elements from an input tensor at the given indices. In its most general form, the tensor of indices specifies all the coordinates of every element to extract (i.e. COO format).

// For each 1x2 triple of coordinates in %indices, extract the element (i.e. 0-D
// subset) at the coordinates triple in %input.
// This corresponds to the COO format, enumerating triples.
// In the IR below, the indexing part of the indices tensor is isolated by a 
// whitespace (e.g. we write tensor<1x2x 3xindex> instead of tensor<1x2x3xindex> to 
// better isolate 3xindex).
%out = gather %input[%indices] : 
  tensor<4x4x4xf32>[tensor<1x2x 3xindex>] -> tensor<1x2xf32>

A slice variant is provided that allows specifying whole slices of the input tensor, to allow rectangular subsets and avoid obvious performance pitfalls where possible.

// For each 5x6 singleton of coordinates %idx in %indices, extract the 2-D slice
// %input[:, %idx, :] (i.e. along dimension 1) from %input.
%out = gather %input[%indices] coordinates = [1] : 
  tensor<2x3x4xf32>[tensor<5x6x 1xindex>] -> tensor<5x6x2x4xf32>

Only full slices are supported; if one desires partial slices, one should compose with other tensor ops such as tensor.extract_slice. This is to avoid a slippery slope of conflation and complexity that would make the op hard to use in practice (e.g. HLO gather and the related Gather HLO is Complex doc).
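
For instance, a partial slice can be obtained by composing a full-slice gather with tensor.extract_slice (notional IR following the conventions above):

// Gather full 2-D slices along dimension 1, then carve out the partial
// [0:2] sub-slice of the trailing dimension with a separate op.
%full = gather %input[%indices] coordinates = [1] : 
  tensor<2x3x4xf32>[tensor<5x 1xindex>] -> tensor<5x2x4xf32>
%partial = tensor.extract_slice %full[0, 0, 0][5, 2, 2][1, 1, 1] : 
  tensor<5x2x2xf32> from tensor<5x2x4xf32>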

The gather / scatter ops aim at providing a generalization safety net (i.e. sparse computations can be represented with it, albeit inefficiently) while avoiding the obvious performance gaps (i.e. if one requires taking N-D slices, rather than enumerating all 0-D points, then just take slices).

Fusing through Gather

As before, “tiling by size 1” allows fusing through gather operations while maintaining structure on the dimensions that are not sliced by fusion.

func.func @matmul_gather_bias(%0: tensor<42x32xf32>, %1: tensor<32x64xf32>, %2: tensor<42x64xf32>, 
                              %indices: tensor<5x1xindex>, %5: tensor<5x64xf32>, %6: tensor<5x64xf32>)
    -> tensor<5x64xf32> {
  %3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
                    outs(%2 : tensor<42x64xf32>)
  %4 = gather %3[%indices] coordinates=[0]: tensor<42x64xf32>[tensor<5x 1xindex>] -> tensor<5x64xf32>
  %7 = linalg.elementwise "sub" ins(%4, %5: tensor<5x64xf32>, tensor<5x64xf32>) outs(%6: tensor<5x64xf32>)
  return %7: tensor<5x64xf32>
}

A notional fusion through gather is expected to resemble:

func.func @matmul_gather_bias(%0: tensor<42x32xf32>, %1: tensor<32x64xf32>, %2: tensor<42x64xf32>, 
                              %indices: tensor<5x1xindex>, %5: tensor<5x64xf32>, %6: tensor<5x64xf32>)
    -> tensor<5x64xf32> {
  %3 = linalg.matmul ins(%0, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
                    outs(%2 : tensor<42x64xf32>)
  %7 = scf.for %i = %c0 to %c5 step %c1 iter_args(%iter = %6: tensor<5x64xf32>) -> tensor<5x64xf32> {
    %index_tensor = tensor.extract_slice %indices[%i, 0][1, 1][1, 1]: tensor<1xindex> from tensor<5x 1xindex>
    %index = tensor.extract %index_tensor[%c0]: tensor<1xindex>

    // linalg.matmul can now be fused here but it becomes a linalg.matvec.
    %a = tensor.extract_slice %3[%index, 0][1, 64][1, 1]: tensor<1x64xf32> from tensor<42x64xf32>
    %b = tensor.extract_slice %5[%i, 0][1, 64][1, 1]: tensor<1x64xf32> from tensor<5x64xf32>
    %c = tensor.extract_slice %iter[%i, 0][1, 64][1, 1]: tensor<1x64xf32> from tensor<5x64xf32>
    
    %d = linalg.elementwise "sub" ins(%a, %b: tensor<1x64xf32>, tensor<1x64xf32>) outs(%c: tensor<1x64xf32>)

    %e = tensor.insert_slice %d into %iter[%i, 0][1, 64][1, 1]: tensor<1x64xf32> into tensor<5x64xf32>
    scf.yield %e: tensor<5x64xf32>
  }
  return %7: tensor<5x64xf32>
}

Similarly to the “fusion through reshape”, fusing through gather induces a lowering in the rank of the structured operation (i.e. from matmul to matvec). Going beyond this “premature lowering” is discussed further down.

Proposal: ConcatOp

A concat op is a limited form of gather. It is an important enough use case to warrant a specific operation and avoid explicit enumeration of indices, even along 1 dimension.

%out = concat %t1 and %t2 along dimension = 1 :
  tensor<3x4x?xf32> and tensor<3x5x?xf32> into tensor<3x9x?xf32>

A notional fusion through concat is expected to be spelled similarly as for gather with a difference in the way single-element extraction is implemented: an arith.select operation is used to switch between %t1 and %t2.
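
A minimal notional sketch of the body of such a fused loop, tiling by 1 along the concatenated dimension of the example above (all names are hypothetical; at element granularity the scf.if below folds into the arith.select mentioned above):

// %j iterates over the 9 = 4 + 5 concatenated positions along dimension 1.
%in_t1 = arith.cmpi slt, %j, %c4 : index
%j2 = arith.subi %j, %c4 : index
%slice = scf.if %in_t1 -> (tensor<3x1x?xf32>) {
  %s = tensor.extract_slice %t1[0, %j, 0][3, 1, %d][1, 1, 1] : tensor<3x1x?xf32> from tensor<3x4x?xf32>
  scf.yield %s : tensor<3x1x?xf32>
} else {
  %s = tensor.extract_slice %t2[0, %j2, 0][3, 1, %d][1, 1, 1] : tensor<3x1x?xf32> from tensor<3x5x?xf32>
  scf.yield %s : tensor<3x1x?xf32>
}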

Similarly to “fusion through reshape” and “fusion through gather”, fusing through concat induces a lowering in the rank of the structured operation. Going beyond this “premature lowering” is discussed further down.

A buffer version of concat could resemble:

%out = concat %t1 and %t2 along dimension = 1 :
   memref<3x4x?xf32> and memref<3x5x?xf32> into memref<3x9x?xf32>

The actual materialization of this operation into lower level code will be subject to tradeoffs related to in-place bufferization and type expressiveness. It is worth noting that one possible lowering is through a 1-D slice version of gather. This would provide a bootstrapping path until a better option is available.

Proposal: Bufferization, Copy, View and The Need for More Expressive Types

As currently specified, and with the existing MLIR type support, collapse_shape, expand_shape, gather, scatter, concat operations may have to lower to memory copies. This is because the buffer type system is not rich enough to allow multiple 1-D views in the same 2-D type (unless it can be statically determined they are all separated by the same constant stride and can be represented as a single 2-D strided memref).

This is visible more clearly in a notional buffer version of the op:

// memref<?x4xf32> is a contiguous buffer of ?x4 elements, gather from random input
// slices must copy to the contiguous output.
%out = gather %input[%indices] coordinates = [1] : 
  memref<4x4xf32>[memref<?x1xindex>] -> memref<?x4xf32>

// Nested/jagged buffer support would allow gather to return a “view” into the data.
%out = gather %input[%indices] coordinates = [1] : 
  memref<4x4xf32>[memref<?x1xindex>] -> memref<? x memref<4xf32>>

In other words, the same “strided memref is n-D rectangular” argument discussed in the context of “fusion through reshape” also prevents in-place bufferization.

Note that bufferization of concat presents some special opportunities by inverting the flow of SSA values. In pseudo-IR, this resembles:

%out_buffer = some_buffer_op … : memref<3x9x?xf32>
…
%t1 = memref.subview %out_buffer[0, 0, 0][3, 4, ?][1, 1, 1]: memref<3x9x?xf32> to memref<3x4x?xf32>
%t2 = memref.subview %out_buffer[0, 4, 0][3, 5, ?][1, 1, 1]: memref<3x9x?xf32> to memref<3x5x?xf32>
// The result of concat is RAUW’ed with %out_buffer.

In the general case and depending how they come into existence in the IR, memref<3x4x?xf32> and memref<3x5x?xf32> are not guaranteed to have the same base pointer and cannot be concatenated into a single memref<3x9x?xf32> without memory copies.

Similarly, sparse tensor concatenation and reshaping currently require copies. In order to concatenate (same dimension), expand (to higher dimensions) or collapse (to lower dimensions) a sparse tensor, we iterate over all stored elements, compute the new indices, and insert these into freshly allocated sparse tensors. An example of collapsing a sparse matrix (2-dim) into a sparse vector (1-dim) is illustrated below in pseudo code. Clearly, here we would like to be able to more efficiently change the “view” into the sparse tensor, without having to move data around to make that happen.

iter = src->toCOO() : tensor<10x10xf32>
coo = newSparseCOO()
while (elem = iter->getNext()) {
  coo->add(reshape(elem.indices), elem.value) // maps i,j to ii
}
s = newSparseTensor(coo) : tensor<100xf32>

These issues can be addressed with the addition of a first-class jagged buffer type.

Proposal: First-Class Jagged Buffer

A nested, jagged, buffer type is an N-D generalization of a 1-D jagged array. It is strictly more expressive than a strided memref and allows in-place bufferization of gather.

As a simple illustration, note how a strided memref<4 x memref<?xf32>> is strictly more expressive than the strided memref<4x?xf32>: the former allows 4 different 1-D memref each with its own ? size and base pointer.
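
As a notional illustration (memref-of-memref is not a legal type today, and the ops below are placeholders), indexing into such a nested buffer would first load the inner memref, each one carrying its own size and base pointer:

// Load the %i-th inner buffer, then an element of it; each row may have a
// different dynamic size.
%row = memref.load %jagged[%i] : memref<4 x memref<?xf32>>
%n   = memref.dim %row, %c0 : memref<?xf32>
%v   = memref.load %row[%j] : memref<?xf32>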

The in-place aspect is dependent on concrete lowerings and typically involves some runtime computation of the start address (i.e. inspector-executor style) or iterating through pointers and performing pointer copies (which is still deemed “in-place” for our purposes). Considerations re. lowering to LLVM are kept for future discussion as different alternatives come with tradeoffs that are too large to unpack here.

A similar argument is valid for concat: it must copy when the base pointers for the underlying buffers are not known to be two contiguous subviews.

%out = concat %t1 and %t2 along dimension = 1 : 
  memref<3x4x?xf32> and memref<3x5x?xf32> into memref<3x9x?xf32>

This could be addressed by allowing a memref<3x9xmemref<?xf32>> to mix base pointers coming from different alloc ops. This is an area subject to multiple tradeoffs and that will require more work and discussion.

With the availability of a nested buffer type, operations that previously required alloc + copy to cope with either the rectangularity of strided memref or more general sparse tensor requirements, can bufferize much more efficiently.

With such an extension, ragged and jagged tensors / buffers will be expressible as first-class citizens in the IR. In the limit, operations and types should fold and canonicalize gracefully to dense at one end of the spectrum (e.g. memref<?x?xf32> and static pointer arithmetic), to sparse at the other end of the spectrum (e.g. memref<?xmemref<?xmemref<f32>>> would define a proper jagged, viz. ITPACK format for the values) and to intermediate forms that make sense.

In other words, a good high-level representation and folding patterns allow us to only pay for the price of the abstractions we use.

Basis of Confidence: Relation To Indirect Convolution and GEMM Implementations

Lifting operations to work on ragged / jagged tensors and buffers yields algorithms that generalize recent work on Indirect GEMM / Indirect Convolution (tech report). This work shows these algorithms are very competitive with traditional dense convolutions and gemm, forming a reasonable basis of confidence that even in the absence of folding all the way down to inlined address computations, the expected performance is very strong.

Proposal: Insert and Extract Subset Op Interfaces with Explicit Copy and View Behavior

The following operations have a flavor of a common notional SubsetOpInterface: tensor.extract_slice, tensor.insert_slice, tensor.collapse_shape, tensor.expand_shape, gather, scatter, memref.collapse_shape, memref.expand_shape, memref.subview and to some extent tensor.pad and concat.

It is still TBD whether a single SubsetOpInterface or a split into a SubsetExtractOpInterface and a SubsetInsertOpInterface is more appropriate. Details will matter as we uncover how TilingInterface and other interfaces need to evolve to support more general types.

These abstractions should also be extended to support both copy and view semantics (once the types are powerful enough to support it) and improved verification for both cases. Such an extension can be achieved by simply adding a first-class copy / view attribute on these interfaces and the ops implementing them.

A copy / view attribute extension is needed to allow expressing the bufferization of some of the proposed ops with the current type system (i.e. starting by explicitly copying is always fine) but it also provides opportunities to copy data into an allocation in fast(er) memory through bufferization.alloc_tensor.

This is already available when e.g. packing with tensor.pad (see section 3.2 of the tech report, for increasing spatial locality and reducing compulsory misses and cache line splitting), and should be generalized.

Such an attribute should be standardized on the ops that insert or extract into / from tensors, subject to further reuse and conflict analysis during bufferization (which may still decide to override the view bit into a copy bit to avoid memory conflicts).

This could resemble:

// Current behavior spelled out more clearly with a `view` attribute.
%r = memref.expand_shape view(%0) [[0, 1], [2]] : memref<?x?xf32> into memref<?x5x?xf32>

// New behavior not yet supported: `copy` and reshape to a faster memory.
%r = memref.expand_shape copy(%0) [[0, 1], [2]] : 
  memref<?x?xf32> into memref<?x5x?xf32, #strided_spec, #address_space>

// memref<?x4xf32> is a contiguous buffer of ?x4 elements, gather from random input
// slices must copy to the contiguous output.
%out = gather copy(%input[%indices]) coordinates = [1] : 
  memref<4x4xf32>[memref<?x1xindex>] -> memref<?x4xf32>

// Nested/jagged buffer support would allow gather to return a “view” into the data.
%out = gather view(%input[%indices]) coordinates = [1] : 
  memref<4x4xf32>[memref<?x1xindex>] -> memref<? x memref<4xf32>>

// This IR has strong implications: the indices *must be exactly the range*
// [%indices[0], %indices[dim(%indices) - 1]].
// We may choose to fail verification or use this as extra information: TBD.
%out = gather view(%input[%indices]) coordinates = [1] : 
  memref<4x4xf32>[memref<?x1xindex>] -> memref<?x4xf32>

Proposal: Control-Flow To Iterate Over More Powerful Tensor and Buffer Types.

As was validated by the sparse work, we fully expect operations that abide by structured IR principles to compose naturally with these new types and operations.

One key abstraction to support proper composition of transformations on structured ops and types is that of control-flow operations that iterate over an operand of a higher-order data type.

For now, MLIR has scf.while and scf.for. We will want new ops that may be better connected to the data types. In particular, we anticipate we will need constructs over index sets, motivated by requirements already found in sparse compilation but not yet formalized more generally.

Today, the sparse compiler rewrites ops to loops (a.k.a tiling by 1 in all dimensions) by using a combination of scf.while / scf.for and scf.if to express co-iteration. This rewriting bridges a semantic gap that is too large, but there is currently no other way to express the co-iterating constructs that are needed to realize sparse iteration spaces. The introduction of a higher-level set-based loop abstraction, which iterates over sets of indices (including the union, intersection, difference, and complement of sets) would preserve semantics (such as the parallel or reduction nature of the loop) and be a better fit to the overall progressive lowering philosophy of MLIR.

In addition, subsequent passes could provide alternative implementations of co-iteration, such as mapping the loops onto the while / for / if implementation currently used, or alternative ways to generate the indices in a set (Knuth describes a two-way merge algorithm in Chapter 5.2.4 of Volume 3 of The Art of Computer Programming; other methods to implement the same exist as well, such as densely expanding one set and sparsely iterating over the other).

The following example, expressing a dot product between two sparse 1-D tensors, illustrates the shortcoming of the current direct lowering to imperative loops of the scf dialect.

// Linalg generic implementation of x = SUM_i a(i) * b(i) for sparse vectors a and b.
func.func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
                          %argb: tensor<?xf64, #SparseVector>,
                          %argx: tensor<f64>) -> tensor<f64> {
  %0 = linalg.generic #trait_dot
     ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
      outs(%argx: tensor<f64>) {
      ^bb(%a: f64, %b: f64, %x: f64):
        %1 = arith.mulf %a, %b : f64
        %2 = arith.addf %x, %1 : f64
        linalg.yield %2 : f64
  } -> tensor<f64>
  return %0 : tensor<f64>
}

During sparsification, the compiler needs to generate a loop that co-iterates over the intersection of the index set of sparse tensor a and sparse tensor b. Lacking set iteration constructs, however, the compiler must immediately lower this to the following while loop and nested if constructs, expressed in pseudo-code below, which considerably obscures the compute semantics to subsequent transformations.

while (ia < pa1_end && ib < pb1_end) {
  ia0 = a1_indices[ia]; // fetch a index
  ib0 = b1_indices[ib]; // fetch b index
  i = min(ia0,ib0);
  if (ia0 == i && ib0 == i) { // only executes when i in both a and b
    x_val += a_values[ia] * b_values[ib];
  }
  ia += (ia0 == i) ? 1 : 0;
  ib += (ib0 == i) ? 1 : 0;
}

A set-based loop iterator would preserve the actual semantics of what is expressed above much better. In addition, it would greatly simplify the implementation of the sparsification pass itself, by implementing the required interfaces to extract (partial) slices of data along particular dimensions. The generation of all the conditionals needed for the co-iteration can then be delegated to a later lowering pass, after all relevant canonicalizations and foldings have been applied on higher-level types.

for ((i, ia, ib) in intersection(a, b)) {  // yields i in intersection and ia, ib
  x_val += a_values[ia] * b_values[ib];
}
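
For contrast, here is a notional sketch (in scf form, with hypothetical buffer names) of the “densely expand one set, sparsely iterate over the other” alternative mentioned earlier, assuming %dense_a is a zero-initialized scratch buffer of the dense vector length:

// Scatter the nonzeros of a into the dense scratch buffer.
scf.for %p = %pa_lo to %pa_hi step %c1 {
  %i = memref.load %a_indices[%p] : memref<?xindex>
  %v = memref.load %a_values[%p] : memref<?xf64>
  memref.store %v, %dense_a[%i] : memref<?xf64>
}
// Iterate only over the nonzeros of b; entries of a outside the intersection read as 0.
%x = scf.for %p = %pb_lo to %pb_hi step %c1 iter_args(%acc = %x_init) -> (f64) {
  %i = memref.load %b_indices[%p] : memref<?xindex>
  %va = memref.load %dense_a[%i] : memref<?xf64>
  %vb = memref.load %b_values[%p] : memref<?xf64>
  %m = arith.mulf %va, %vb : f64
  %acc_next = arith.addf %acc, %m : f64
  scf.yield %acc_next : f64
}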

Proposal: Index Semantics Capturing Ops

To properly support such higher-order control-flow operations and bridge the gap to considerations such as efficient lowering of “fusion through reshape and gather”, we anticipate higher-order operations that express indexing semantics on whole tensors and structured sets will be useful.

We believe this will be useful to avoid committing to loops over scalar load / store too early and to capture high-level IR that is transported to the proper places to enable efficient traversals to be derived (both control-flow and type access).

This design principle has served us well in the vector dialect.

In a world where we only have 0-D affine.apply, affine.linearize, affine.delinearize and n-D full gather / scatter enumeration, we will likely need ops with a region to carry the relevant information until it can be folded into static address calculations or more intrusive runtime enumerations.

// Some examples that illustrates index ops on sets with a region.
// Values from min to max incremented by step, generalizable to n-D, rectangular.
%0 = indexing.numpy_like_linspace %min, %max, %step : tensor<?xindex>
// Half-open interval: returns a tensor<1xindex> as one would expect.
%0 = indexing.numpy_like_linspace 42, 43, 1000 : tensor<1xindex> 
// We statically know (%max - %min) / %step is 4 but not much else.
%0 = indexing.numpy_like_linspace %min, %max, %step : tensor<4xindex> 

// General op with a region to represent linearize / delinearize on the whole set 
// and avoid lowering to loops prematurely.
// The result is not necessarily a rectangular tensor and may require a ragged / jagged memref.
%1 = indexing.map (%a) in range(%0 : tensor<?xindex>) -> tensor<?x?xindex> {
  %r0:2 = affine.delinearize_index %a into (%b0, %b1): index, index
  return %r0#0, %r0#1 : index, index 
}
}

// Space filling curves (used e.g. in the Ruy library) are known to induce strong locality, are
// cheap to compute and hard to recover from lower-level bit-twiddling IR.
%2 = indexing.some_hilbert_curve_based_reordering_of_indices %1 : tensor<?x?xindex>

// %1 and %2 can further be fed to a gather op that takes multiple 1-D slices 
// to implement an indirect convolution algorithm via e.g. a memref<?xmemref<?xf32>>.
// Depending on the amount of static information, this could all be unrolled and 
// lowered away to avoid materializing indices in memory.

The concept of IntegerSet is also expected to be useful in this context as it models a well-known class of higher-order sets and will properly lower to affine abstractions in MLIR by design.
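
For reference, integer sets already exist in the affine dialect; a 2-D example and its use as a guard:

// All (d0, d1) with 0 <= d0 < s0 and d1 >= d0.
#valid = affine_set<(d0, d1)[s0] : (d0 >= 0, s0 - d0 - 1 >= 0, d1 - d0 >= 0)>

affine.if #valid(%i, %j)[%N] {
  // Executes only when (%i, %j) lies inside the set.
}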

Union, intersection, difference, and complement operations on sets of indices are also expected to be useful to partition and tile data, e.g. for parallelization and load balancing. These should be first-class operations to avoid some well-known exponential complexity blowups and the resulting impenetrable spaghetti code that are all too common with polyhedral techniques.

General implementations may rely on type erasure and fast underlying C++ library implementations of sets and hash maps. Such precedent has been useful in the past to bring up structured and sparse codegen without having to frontload the implementation of complex low-level types in LLVM.

Future Work (Outside of The Current Proposal)

  • Looping constructs and abstractions for parallel and distributed extensions.
  • Generalized data abstractions in MLIR (lists, sets, maps).
  • Highly optimized LLVM representation of types.


Can gather allow for an optional out-of-bounds access mode? For example, this could be utilized to abstract away gathering on the result of a pad operation.

I have spliced out an RFC specifically for gather / scatter:

I’ll reply there.

Only providing vague thoughts on the location of code like this…

I think this proposal is trying to allow users of this dense → ragged → sparse type-lowering to be as much a first class citizen as existing use cases. Would these efforts impose any additional work or struggles upon users of existing supported situations?

Similarly, this proposal wants gather/scatter to “compose with the rest of MLIR,” so placing it somewhere like the Tensor dialect makes sense, but it also provides more functionality than most existing ops there (such as fusion), expands the existing SubsetOpInterface, and interacts with the new indexing control-flow operations mentioned. These all remind me of recent discussions about the ML-related transformations dialect or about expanding the linalg dialect, in that there are many related transformations involved. Given these multiple desired transformations, I see benefits to this being like the bufferization dialect: representing these type transformations and centering their handling on that dialect, like the other recent transformation-centered dialect conversations. Are you wanting to avoid something like a separate dialect for this, and could you describe why if so?

This is where breaking up pieces of this broader vision into semantically meaningful and composable units is going to be important. IOW:

This is precisely why the abstractions need to be sliced and evaluated independently, while still keeping track of the general “north star”. Gather / scatter ops are one thing, transformations and rewrites are a separate thing, functionality is a third thing etc.

Yes and the MLIR way has been to factor out common properties and principles into the right interfaces. I expect this trend to continue and to gradually start incorporating more irregular (but still structured) cases than “dense rectangular”.

Yes, Gather / Scatter will implement BufferizableOpInterface. The other transformations will use other interfaces (e.g. TilingInterface). This is part of composing with the rest of MLIR.

Providing a guiding path towards where we are interested in evolving some of the structured abstractions is very different from proposing to cram everything into a new dialect: putting everything in the same place just for the sake of it is undesirable. Instead, ops and interfaces with well-defined semantics that compose with the rest is what we should aspire to.

Of those, the indexing abstractions could potentially warrant a new dialect, as this notion of an indexing dialect has surfaced a bunch of times in different contexts over the years. Historically, the fact that common abstractions emerge, from first principles, from different parts of the system has carried a strong value signal.

Indexing abstractions could also make sense as part of the sparse_tensor dialect. This should be evaluated once we have more data and arrive at a point where we are ready to make a concrete proposal.

Thanks for writing all this up!

There is a lot to unpack here and I suppose most of these things will be sent as separate more concrete proposals. It would be very helpful to have a short “design philosophy” paragraph that can be used to check if future proposals fit into the structured codegen vision. Especially if we expect more people to contribute proposals. There is the underlying notion of capturing structure, but there seems to be multiple IR entities that can be used: new operations (purely control flow or combined control flow and compute), new/customized types and some mix of both. Any general clues on when to introduce those and how would be great.

I have a couple of design questions, specifically about memref-of-memref vs. a more specialized ragged array type, and the overall concept of inplacedness, but it would probably make sense to discuss them on more concrete proposals.

+1

I’d say that I am also keenly interested in the development process here. We’re currently in something of an “island of stability” design wise on this stuff and I think that is a good state to be in (what these parts of the system do today is really important). I’d like to see us invest in these directions, but I think we should find a way to do so that encourages experimentation and critique outside of the upstream codebase (fork, incubator, etc – not going to suggest a method here). Ideally, this process produces something concrete that we can look at and say unambiguously “that fits there”. Also, doing that work in a bit more of a disconnected fashion will further drive evolution of some of the core types and layering that would currently bias such work to happen in-situ in the upstream codebase.


Calling out two sparsity specific issues here: concrete in-memory storage schemes and set-based iteration.

To address the most important IR opacity issue, the sparse compiler team is planning to replace the conversion of sparse tensors and primitives to opaque pointers and calls into a runtime support library with actual code generation into buffers and actual IR (see D132766 for a very tiny first step on this journey). The use of a library certainly increased velocity of execution while ramping up sparse compilation, especially since MLIR is not very suitable yet for setting up elaborate data structures. For many reasons, however, the time has come to consider generating actual code, and to partially or even completely eliminate the dependence on a support library. For this to be successful, support beyond rectangular arrays will be essential.

Next, set-based iterative constructs would greatly refine the ability for progressive lowering during “sparsification”. Although lower priority than the first issue in this post, in the long term, supporting this construct would contribute much to the code health of the sparse compiler, which will then be able to lower code much more gradually, with all obvious advantages.

2 Likes