[RFC] [Tensor] Extracting slices from `tensor.collapse_shape`


This RFC concerns adding a utility to the tensor dialect to replace the result of a
tensor.collapse_shape -> tensor.extract_slice chain with the
equivalent result formed by aggregating slices of the
tensor.collapse_shape source.

To accomplish this, one new auxiliary operation is proposed, some_dialect.delinearize_index, that accepts a linear index and returns a multi-index for some specified basis.
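For reference, the intended semantics can be modeled in a few lines of Python. This is a sketch that assumes a row-major basis (the last basis entry varies fastest) and an in-range linear index; the function name `delinearize` is hypothetical and not an MLIR API.

```python
from __future__ import annotations


def delinearize(linear: int, basis: list[int]) -> list[int]:
    """Multi-index of `linear` in a row-major space of shape `basis`.

    Assumes 0 <= linear < product(basis).
    """
    multi = []
    for extent in reversed(basis):          # peel off the fastest-varying dim first
        linear, rem = divmod(linear, extent)
        multi.append(rem)
    return multi[::-1]


# Row 37 of a collapsed 100-row dimension is row (3, 7) of a 10x10 source.
print(delinearize(37, [10, 10]))  # [3, 7]
```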

The below text describes the problem and illustrates the proposed changes introduced to provide a solution.


The goal is to solve the following problem:

  1. You have some program consisting of operations on Tensors that looks like the following IR listing. I am using linalg.generic here to represent some operation that implements TilingInterface, and some specifics are omitted to highlight the key issue.
// S0:
%0 = linalg.generic ins(...) outs(...) -> tensor<100x100xf32>
// S1:
%1 = linalg.generic ins (....) outs(...)....  -> tensor<10x10x100xf32>
// S2:
%2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<10x10x100xf32> into tensor<100x100xf32>
// S3:
%3 = linalg.matmul ins(%2, %0 : tensor<100x100xf32>, tensor<100x100xf32>)
                   outs(%init : tensor<100x100xf32>) -> tensor<100x100xf32>

  2. Apply tiling to S3:
// S0:
%0 = linalg.generic ins(...) outs(...) -> tensor<100x100xf32>
// S1:
%1 = linalg.generic ins (....) outs(...)....  -> tensor<10x10x100xf32>
// S2:
%2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<10x10x100xf32> into tensor<100x100xf32>
// S3:
%3 = scf.for %iv = %c0 to %c100 step %c10 iter_args(%iter_arg = %init) -> (tensor<100x100xf32>) {
    // S4: tiling the rows of the matmul needs all of the right-hand operand %0.
    %slice0 = tensor.extract_slice %0 [0, 0] [100, 100] [1, 1] : tensor<100x100xf32> to tensor<100x100xf32>
    // S5: rows [%iv, %iv + 10) of the left-hand operand %2.
    %slice2 = tensor.extract_slice %2 [%iv, 0] [10, 100] [1, 1] : tensor<100x100xf32> to tensor<10x100xf32>
    %out = tensor.extract_slice %iter_arg [%iv, 0] [10, 100] [1, 1] : tensor<100x100xf32> to tensor<10x100xf32>
    // S6:
    %4 = linalg.matmul ins(%slice2, %slice0 : tensor<10x100xf32>, tensor<100x100xf32>)
                       outs(%out : tensor<10x100xf32>) -> tensor<10x100xf32>
    %5 = tensor.insert_slice %4 into %iter_arg [%iv, 0] [10, 100] [1, 1] : tensor<10x100xf32> into tensor<100x100xf32>
    scf.yield %5 : tensor<100x100xf32>
}
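The shapes in S2 follow from the reassociation: each result extent is the product of the source extents in its group. A minimal Python model, assuming static shapes; `collapsed_shape` is a hypothetical helper used only for illustration:

```python
from __future__ import annotations
from math import prod


def collapsed_shape(src_shape: list[int], reassociation: list[list[int]]) -> list[int]:
    """Result shape of collapse_shape: one extent per reassociation group."""
    # Groups must cover every source dimension exactly once, in order.
    flat = [d for group in reassociation for d in group]
    if flat != list(range(len(src_shape))):
        raise ValueError(f"invalid reassociation {reassociation} for rank {len(src_shape)}")
    return [prod(src_shape[d] for d in group) for group in reassociation]


print(collapsed_shape([10, 10, 100], [[0, 1], [2]]))  # [100, 100]
```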

At this point you can invoke linalg.generic’s implementation of TilingInterface to replace S4 with a tiled version of S0 that computes only the part of %0 that is needed. The same is true of any other operation that implements TilingInterface. For S5, however, the tensor.collapse_shape at S2 blocks this, even though its source %1 is also the result of a TilingInterface operation (S1).

In general, it is not possible to exchange extract_slice and collapse_shape if linearized dimensions
are sliced.
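Why the exchange fails can be seen by listing the source indices that a slice of a linearized dimension touches. In this Python sketch (helper names are hypothetical), rows [5, 15) of the collapsed 100-row dimension of %2 map to source pairs (0, 5) through (1, 4) of the 10x10 dims of %1. That set is not a rectangular box, so no single extract_slice of %1 yields it. An aligned tile such as rows [10, 20) happens to be rectangular, but with a dynamic offset like %iv the rewrite cannot rely on that.

```python
from __future__ import annotations
from itertools import product


def source_rows(offset: int, size: int, inner: int) -> list[tuple[int, int]]:
    """Source (i, j) pairs for rows [offset, offset + size) of a dim collapsed from (outer, inner)."""
    return [divmod(r, inner) for r in range(offset, offset + size)]


def is_contiguous(xs: list[int]) -> bool:
    return xs == list(range(xs[0], xs[-1] + 1))


def is_rectangular(pairs: list[tuple[int, int]]) -> bool:
    """True if `pairs` is a box of contiguous ranges, i.e. one unit-stride extract_slice."""
    rows = sorted({i for i, _ in pairs})
    cols = sorted({j for _, j in pairs})
    return (is_contiguous(rows) and is_contiguous(cols)
            and sorted(pairs) == list(product(rows, cols)))


print(is_rectangular(source_rows(10, 10, 10)))  # True: exactly source row i == 1
print(is_rectangular(source_rows(5, 10, 10)))   # False: straddles i == 0 and i == 1
```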

Just to be very clear, the i-th dimension of the tensor.collapse_shape
result is a “linearized sliced dimension” if:

  1. The reassociation group at position i of the tensor.collapse_shape has more than one entry (multiple dimensions of the input are collapsed).
  2. The i-th dimension is sliced by tensor.extract_slice.
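The two conditions can be written as a predicate over a static description of the ops. This is a Python sketch under stated assumptions: unit strides, and `None` standing in for a dynamic offset or size, which is conservatively treated as slicing the dimension. The function name `linearized_sliced_dims` is hypothetical.

```python
from __future__ import annotations
from typing import Optional


def linearized_sliced_dims(
    reassociation: list[list[int]],
    result_shape: list[int],
    offsets: list[Optional[int]],
    sizes: list[Optional[int]],
) -> list[int]:
    """Indices i of the collapse_shape result that are linearized sliced dimensions."""
    dims = []
    for i, group in enumerate(reassociation):
        collapses_several = len(group) > 1                           # condition 1
        takes_whole_dim = offsets[i] == 0 and sizes[i] == result_shape[i]
        if collapses_several and not takes_whole_dim:                # condition 2
            dims.append(i)
    return dims


# S5: %2[%iv, 0] [10, 100] with %2 = collapse_shape %1 [[0, 1], [2]].
print(linearized_sliced_dims([[0, 1], [2]], [100, 100], [None, 0], [10, 100]))  # [0]
```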

If there are no “linearized sliced dimensions” (i.e. the extract_slice leaves the collapsed dims alone), then you can do a simple rearrangement to get the tile of %2: slice the source %1 first, then collapse the slice. I believe these optimizations are not yet implemented either.
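That rearrangement only has to carry the offsets and sizes of non-collapsed dims over to the source and keep every collapsed group whole. A Python sketch, assuming static offsets and sizes and unit strides; `source_slice` is a hypothetical helper:

```python
from __future__ import annotations
from math import prod


def source_slice(
    reassociation: list[list[int]],
    src_shape: list[int],
    offsets: list[int],
    sizes: list[int],
) -> tuple[list[int], list[int]]:
    """Offsets and sizes of an extract_slice on the collapse_shape source whose
    collapse (with the same reassociation) equals the requested tile of the result.

    Precondition: no linearized dimension is sliced.
    """
    src_offsets = [0] * len(src_shape)
    src_sizes = list(src_shape)
    for i, group in enumerate(reassociation):
        if len(group) == 1:
            # A non-collapsed dim: the slice carries over unchanged.
            src_offsets[group[0]] = offsets[i]
            src_sizes[group[0]] = sizes[i]
        elif offsets[i] != 0 or sizes[i] != prod(src_shape[d] for d in group):
            raise ValueError(f"result dim {i} is a linearized sliced dimension")
        # Otherwise the whole group is taken, so its source dims stay full.
    return src_offsets, src_sizes


# Columns [20, 30) of %2 = collapse_shape %1 [[0, 1], [2]] come from %1[0, 0, 20] [10, 10, 10].
print(source_slice([[0, 1], [2]], [10, 10, 100], [0, 20], [100, 10]))
# ([0, 0, 20], [10, 10, 10])
```

Collapsing that 10x10x10 tile with [[0, 1], [2]] then yields the 100x10 tile of %2.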

Potential Solutions

The solution space is basically:

  • Potential Solution 1: Move the collapse_shape. In this strategy we push/pull reshapes to the periphery of the block in order to minimize their ability to block tiling. In this case, we could “pull” the reshape to the top of the block if we expand the operation in S3 to operate on higher dims.
  • Potential Solution 2: You can insert IR to compute the result of S5 in terms of %2 and the tiling parameters.

The first solution is implemented in linalg currently (see for example this RFC from @MaheshRavishankar in February). This is probably preferable if you’re within the bounds of that strategy, but there are some situations where you might be out of its bounds:

  1. It’s just not possible to move the reshape operation. For example, that could be due to hard linalg.generic representation limitations (for example, you can’t “fuse in” a collapse_shape on the output into the indexing_map of an input). It could be due to other limitations if the operation is something other than linalg.generic.
  2. You may not want to make the originally tiled op in S3 operate on higher dims. For example, if S3 is a matrix multiplication, and the linearized dimensions correspond to the parallelizable dimensions of one of the input operands, then you may prefer collapsing multiple parallel dimensions into a single dimension via this representation. The reason may be to achieve some load-balancing effect if the tiles are distributed to threads. Similar arguments can be made if the linearized dimensions are contraction dimensions.

Any combination of these situations can get you to the point where you now need Potential Solution 2.

This RFC then concerns adding a rewrite to perform Potential Solution 2 in a naive but generic manner. It introduces minimal new abstractions/ops to accomplish this.
We can accomplish this by stitching together the result of multiple tensor.extract_slice’s formed by iterating over any linearized and sliced dimensions. This is equivalent to tiling the linearized dimensions by 1 or viewing the tensor.collapse_shape as a gather operation on those indices. You could also view this as a TilingInterface implementation of tensor.collapse_shape.
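The following Python model shows the stitching on nested lists, assuming the [[0, 1], [2]] reassociation from S2 with only collapsed dim 0 sliced. The helper names are hypothetical; the point is that gathering rows through a delinearized index gives exactly the tile that collapsing first and slicing afterwards would.

```python
from __future__ import annotations


def delinearize(linear: int, basis: list[int]) -> list[int]:
    """Row-major multi-index of `linear` for the given basis."""
    multi = []
    for extent in reversed(basis):
        linear, rem = divmod(linear, extent)
        multi.append(rem)
    return multi[::-1]


def collapse_then_slice(src, offset: int, size: int):
    """Reference: collapse [[0, 1], [2]], then take rows [offset, offset + size)."""
    collapsed = [row for plane in src for row in plane]
    return collapsed[offset:offset + size]


def slice_through_collapse(src, offset: int, size: int):
    """The rewrite: build the same tile from 1x1xC slices of the source.

    The linearized dim is tiled by 1: each iteration delinearizes one row index
    and copies one row of `src` into position linear_idx of the result.
    """
    basis = [len(src), len(src[0])]         # extents of the collapsed dims 0 and 1
    tile = []
    for linear_idx in range(size):
        i, j = delinearize(offset + linear_idx, basis)
        tile.append(list(src[i][j]))         # extract_slice [i, j, 0] [1, 1, C], then insert_slice
    return tile


# A 2x4x3 source whose elements encode their own (i, j, k) position.
src = [[[100 * i + 10 * j + k for k in range(3)] for j in range(4)] for i in range(2)]
print(slice_through_collapse(src, 3, 3))
# [[30, 31, 32], [100, 101, 102], [110, 111, 112]]
```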

This may or may not be actually beneficial based on any number of factors, so the changes proposed really just introduce this as a targeted utility function + a test pass.

Here’s an example of what this looks like (to replace S5 above):

%slice2_init = linalg.init_tensor [10, 100] : tensor<10x100xf32>
%collapse_shape_tile = scf.for %linear_idx = %c0 to %c10 step %c1 iter_args(%acc = %slice2_init) -> (tensor<10x100xf32>) {
   // Row %row of %2 is row (%multi_index#0, %multi_index#1) of %1: delinearize it
   // using the sizes of the collapsed source dims (10, 10) as the basis.
   %row = arith.addi %iv, %linear_idx : index
   %multi_index:2 = some_dialect.delinearize_index %row %c10, %c10
   %input = tensor.extract_slice %1 [%multi_index#0, %multi_index#1, 0] [1, 1, 100] [1, 1, 1] : tensor<10x10x100xf32> to tensor<1x100xf32>
   // Tiled form of S1: fusion replaces %input with this tile
   // (the tensor.extract_slice's feeding it are omitted).
   %tile_slice = linalg.generic ins(...) outs(...) -> tensor<1x100xf32>
   %tile_update = tensor.insert_slice %tile_slice into %acc [%linear_idx, 0] [1, 100] [1, 1] : tensor<1x100xf32> into tensor<10x100xf32>
   scf.yield %tile_update : tensor<10x100xf32>
}

Here the only new op introduced is some_dialect.delinearize_index, which gives the multi-index corresponding to a linearized index for a given basis. It is added in the spirit of “gradual lowering” and because it could be the subject of additional optimizations. It’s not clear where that op should go (it originally went into arith, but I was then asked to revert it and write this RFC.)

More alternatives

Concerning the actual implementation of Potential Solution 2, the second diff below adds this functionality, where the iteration is accomplished by scf.for or scf.foreach_thread. Other abstractions could replace this implementation and accomplish the same effect; this initial implementation could be made irrelevant if such abstractions are introduced in tensor. I believe @nicolasvasilache has some plans here.


  1. [mlir][Arithmetic] Add arith.delinearize_index operation

Note that there was a related discussion from an RFC a little over a year ago that proposed similar linearize/delinearize index ops in the context of memref dialect. There could be some set of canonicalizations that operate on such ops in order to perform targeted simplifications, but that’s outside this use case.

  2. [mlir][Tensor] Add rewrites to extract slices through tensor.collapse_shape

Porting over some comments from Phabricator:

@River707 commented that the arith dialect may not be the right place for the delinearize_index op, because it does not meet the typical definition of most operations in arith, which are elementwise-mappable operations that can operate on different types.

I can move that operation back to the tensor dialect and change the semantics to operate on tensors.

It might be useful to decouple this operation from the specific use case it was developed under and view it in isolation. I see the delinearize operation added here as useful beyond just the use case mentioned here. This operation, together with a counterpart linearize operation that takes a list of indices and a basis and generates a linearized index, can be a useful way to move between multi-dimensional and one-dimensional spaces. Since the two operations fold away when their bases match, this can be a very powerful way of reducing index arithmetic cost.
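The folding argument rests on linearize and delinearize being inverses for in-range indices when they share a basis. A small Python check of that property, assuming row-major order; `linearize` and `delinearize` here are hypothetical models of the ops, not MLIR APIs:

```python
from __future__ import annotations
from itertools import product
from math import prod


def linearize(multi: list[int], basis: list[int]) -> int:
    """Row-major linear index of `multi` in a space of shape `basis`."""
    linear = 0
    for idx, extent in zip(multi, basis):
        linear = linear * extent + idx
    return linear


def delinearize(linear: int, basis: list[int]) -> list[int]:
    """Inverse of `linearize` for 0 <= linear < prod(basis)."""
    multi = []
    for extent in reversed(basis):
        linear, rem = divmod(linear, extent)
        multi.append(rem)
    return multi[::-1]


basis = [4, 5, 6]
# With matching bases the pair composes to the identity in both directions,
# which is what allows a folder to erase it.
assert all(linearize(delinearize(x, basis), basis) == x for x in range(prod(basis)))
assert all(delinearize(linearize(list(m), basis), basis) == list(m)
           for m in product(*(range(e) for e in basis)))
print(linearize([3, 7], [10, 10]))  # 37
```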

For example, structured-op based code generation (and maybe even the path through the Affine dialect) deals with multi-dimensional iteration spaces and multi-dimensional data objects. Eventually all the data objects get collapsed into a pointer + offset. Similarly, once there is no more advantage to staying in a multi-dimensional iteration space (i.e. somewhere lower in the compilation stack), you could fold it into a 1D iteration space. Linearizing both the iteration space and the data space this way typically leads to:
(a) delinearizing the induction variable using some basis to recover the multi-dimensional iteration space.
(b) linearizing the induction variable again (using the same basis) into a one dimensional index for accessing data.

If all of this math is represented using affine_maps, you could write a (fairly) complicated folder to recognize this chain and make it a no-op. Having the delinearize and linearize ops would allow you to write this folder relatively easily.
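How simple that folder becomes can be illustrated with a toy expression tree in Python. `Var`, `Delinearize`, `Linearize` and `fold` are hypothetical stand-ins, not MLIR classes; a real implementation would be a folder or rewrite pattern on the ops themselves. The fold fires only when the bases are identical, since equal products are not enough (for example, (10, 10) and (5, 20) linearize [3, 7] differently).

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Delinearize:
    operand: Expr             # a linear index
    basis: tuple[int, ...]


@dataclass(frozen=True)
class Linearize:
    operand: Expr             # a multi-index, e.g. the results of a Delinearize
    basis: tuple[int, ...]


Expr = Union[Var, Delinearize, Linearize]


def fold(expr: Expr) -> Expr:
    """Peephole folder: linearize(delinearize(x, b), b) -> x, applied bottom-up."""
    if isinstance(expr, Var):
        return expr
    inner = fold(expr.operand)
    if isinstance(expr, Linearize) and isinstance(inner, Delinearize) and inner.basis == expr.basis:
        return inner.operand
    return type(expr)(inner, expr.basis)


iv = Var("iv")
print(fold(Linearize(Delinearize(iv, (10, 10)), (10, 10))))  # Var(name='iv')
mismatched = Linearize(Delinearize(iv, (10, 10)), (5, 20))
print(fold(mismatched) == mismatched)  # True: different bases, nothing to fold
```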
Considering the above, my suggestion is to move this into the affine dialect, which has a lot of smarts w.r.t. index computation; this would fit right in.

@River707 thoughts?


The descriptions here provide much more context on what the purpose of this is (I really appreciate the writeup). I can see the reasoning and value a bit more clearly now, but it would be nice to condense some of the justification here into something that can be placed into the docs. I have a strong reluctance to this living in the arith dialect, because that dialect shouldn’t devolve into tensor/shaped index computation (that isn’t its goal). I think either affine or tensor can make sense to me, given the intended use cases described here.

– River


+1 on the thought about not putting it in arith.

My first thought was “this would be the kind of thing that would be put in the ‘index’ dialect that we decided not to make when splitting things up”. Agreed that ‘affine’ is the closest we have right now.


Seems like a very “affine” modeling to me as well!

All, thanks for your comments, I’ll redirect the first patch to the affine dialect.

Thanks for spelling out the use case, affine seems like a reasonable landing place for now.

+1 I expect we’ll start needing an ‘index’ dialect soonish.

Well, I was thinking about splitting out the high-level description from the computational aspects in the Shape dialect too, so perhaps yes 🙂