#### Summary

This RFC concerns adding a utility to the `tensor` dialect to replace the result of a `tensor.collapse_shape -> tensor.extract_slice` chain with the equivalent result formed by aggregating slices of the `tensor.collapse_shape` source.

To accomplish this, one new auxiliary operation is proposed, `some_dialect.delinearize_index`, that accepts a linear index and returns a multi-index for some specified basis.

The below text describes the problem and illustrates the proposed changes introduced to provide a solution.

#### Problem

The goal is to solve the following problem:

- You have some program consisting of operations on tensors that looks like the following IR listing. I am using `linalg.generic` here to represent some operation that implements `TilingInterface`, and some specifics are omitted to highlight the key issue.

```
// S0:
%0 = linalg.generic ins(...) outs(...) -> tensor<100x100xf32>
// S1:
%1 = linalg.generic ins(...) outs(...) -> tensor<10x10x100xf32>
// S2:
%2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<10x10x100xf32> into tensor<100x100xf32>
// S3:
%3 = linalg.matmul ins(%2, %0 : tensor<100x100xf32>, tensor<100x100xf32>)
                   outs(%init : tensor<100x100xf32>) -> tensor<100x100xf32>
```

- Apply tiling to S3:

```
// S0:
%0 = linalg.generic ins(...) outs(...) -> tensor<100x100xf32>
// S1:
%1 = linalg.generic ins(...) outs(...) -> tensor<10x10x100xf32>
// S2:
%2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<10x10x100xf32> into tensor<100x100xf32>
// S3:
%3 = scf.for %iv = %c0 to %c100 step %c10 iter_args(%iter_arg = %init) -> tensor<100x100xf32> {
  // S4:
  %0_slice = tensor.extract_slice %0[%iv, 0] [10, 100] [1, 1] : tensor<100x100xf32> to tensor<10x100xf32>
  // S5:
  %2_slice = tensor.extract_slice %2[%iv, 0] [10, 100] [1, 1] : tensor<100x100xf32> to tensor<10x100xf32>
  // S6:
  %4 = linalg.generic ins(%2_slice, %0_slice : tensor<10x100xf32>, tensor<10x100xf32>)
                      outs(...) -> tensor<10x100xf32>
  ....
  %5 = tensor.insert_slice %4 into %iter_arg[%iv, 0] [10, 100] [1, 1] : tensor<10x100xf32> into tensor<100x100xf32>
  scf.yield %5 : tensor<100x100xf32>
}
```

At this point you can invoke `linalg.generic`'s implementation of `TilingInterface` to replace S4 with a tiled version of `%0`. The same is true of any other operation that implements `TilingInterface`. For S5, however, the `tensor.collapse_shape` at S2 will block this from occurring, even though `%2` is also the result of a `TilingInterface` operation.

In general, it is not possible to exchange `tensor.extract_slice` and `tensor.collapse_shape` if linearized dimensions are sliced. Just to be very clear, the i-th dimension of the `tensor.collapse_shape` result is a "linearized sliced dimension" if:

- the reassociation indices of the `tensor.collapse_shape` in the i-th position have size greater than 1 (multiple dimensions of the input are collapsed), and
- the i-th dimension is sliced by the `tensor.extract_slice`.

If there are no "linearized sliced dimensions" (e.g. if the `tensor.extract_slice` leaves the collapsed dims alone), then you can do a simple rearrangement to get the tile of `%2`, but I believe these optimizations are not yet implemented either.
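To make that rearrangement concrete, here is a small NumPy sketch (my own illustration, not part of the proposed changes) showing that when only a *non*-collapsed dimension is sliced, slicing and collapsing commute, so the exchange is trivially valid:

```python
import numpy as np

# Source of the collapse_shape: a 10x10x100 tensor whose first two dims
# are collapsed into one, as in S2 above.
src = np.arange(10 * 10 * 100, dtype=np.float32).reshape(10, 10, 100)
collapsed = src.reshape(100, 100)

# Slice only the non-collapsed dimension (columns 20..70). The collapsed
# row dimension is taken whole, so there is no linearized sliced dimension.
slice_then_collapse = src[:, :, 20:70].reshape(100, 50)
collapse_then_slice = collapsed[:, 20:70]

# The two orders agree: the extract_slice can simply be moved above the
# collapse_shape with adjusted reassociation.
assert np.array_equal(slice_then_collapse, collapse_then_slice)
```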

#### Potential Solutions

The solution space is basically:

- Potential Solution 1: Move the `collapse_shape`. In this strategy we push/pull reshapes to the periphery of the block in order to minimize their ability to block tiling. In this case, we could "pull" the reshape to the top of the block if we expand the `linalg.generic` in S3 to operate on higher dims.
- Potential Solution 2: Insert IR to compute the result of S5 in terms of `%2` and the tiling parameters.

The first solution is currently implemented in `linalg` (see, for example, this RFC from @MaheshRavishankar in February). This is probably preferable if you're within the bounds of that strategy, but there are some situations where you might be outside them:

- It's just not possible to move the reshape operation. For example, that could be due to hard `linalg.generic` representation limitations (you can't "fuse in" a collapse_shape on the output into the indexing_map of an input), or due to other limitations if the operation is something other than `linalg.generic`.
- You may not want to make the originally tiled op in S3 operate on higher dims. For example, if S3 is a matrix multiplication and the linearized dimensions correspond to the parallelizable dimensions of one of the input operands, then you may prefer collapsing multiple parallel dimensions into a single dimension via this representation. The reason may be to achieve some load-balancing effect if the tiles are distributed to threads. Similar arguments can be made if the linearized dimensions are a contraction dimension.

Any combination of these situations can get you to the point where you now need Potential Solution 2.

This RFC then concerns adding a rewrite to perform Potential Solution 2 in a naive but generic manner. It introduces minimal new abstractions/ops to accomplish this.

We can accomplish this by stitching together the results of multiple `tensor.extract_slice`s, formed by iterating over any linearized and sliced dimensions. This is equivalent to tiling the linearized dimensions by 1, or to viewing the `tensor.collapse_shape` as a gather operation on those indices. You could also view this as a `TilingInterface` implementation of `tensor.collapse_shape`.
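As a sanity check on that equivalence, here is a NumPy sketch (my own illustration) that rebuilds a tile of the collapsed result by iterating the linearized dimension by 1 and gathering the corresponding slices of the source:

```python
import numpy as np

src = np.arange(10 * 10 * 100, dtype=np.float32).reshape(10, 10, 100)
collapsed = src.reshape(100, 100)

# The tile of the collapsed result we want: 10 rows starting at iv. The row
# dimension is a linearized sliced dimension (it collapses dims 0 and 1 of src).
iv = 30
expected = collapsed[iv:iv + 10, :]

# Rebuild the tile by tiling the linearized dimension by 1: delinearize each
# absolute row index over the basis (10, 10) and gather that slice of src.
tile = np.empty((10, 100), dtype=np.float32)
for linear_idx in range(10):
    i, j = divmod(iv + linear_idx, 10)  # delinearize over basis (10, 10)
    tile[linear_idx, :] = src[i, j, :]

assert np.array_equal(tile, expected)
```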

This may or may not actually be beneficial depending on any number of factors, so the changes proposed really just introduce this as a targeted utility function + a test pass.

Here's an example of what this looks like (to replace S5 above):

```
%2_slice_init = linalg.init_tensor [10, 100] : tensor<10x100xf32>
%collapse_shape_tile = scf.for %linear_idx = %c0 to %c10 step %c1 iter_args(%iter_arg = %2_slice_init) -> tensor<10x100xf32> {
  // Compute the absolute row index of the tile, then delinearize it based on
  // the collapsed dimensions (10x10) of %1.
  %row = arith.addi %iv, %linear_idx : index
  %multi_index:2 = some_dialect.delinearize_index %row %c10, %c10
  %input = tensor.extract_slice %1[%multi_index#0, %multi_index#1, 0] [1, 1, 100] [1, 1, 1]
  // Tiled form of S1 (tensor.extract_slice's for the other operands omitted).
  %tile_slice = linalg.generic ins(...) outs(%input) -> tensor<1x100xf32>
  %tile_update = tensor.insert_slice %tile_slice into %iter_arg[%linear_idx, 0] [1, 100] [1, 1]
  scf.yield %tile_update
}
```

Here the only new op introduced is `some_dialect.delinearize_index`, which gives the multi-index corresponding to a linearized index. It is added in the spirit of "gradual lowering" and because it could be the subject of additional optimizations. It's not clear where that op should go (it originally went into `arith`, but then I was asked to revert and make this RFC).
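For reference, the intended semantics can be sketched in a few lines of Python (my own sketch; the row-major convention, where the last basis entry varies fastest, is an assumption matching the usage above):

```python
def delinearize_index(linear_index, basis):
    """Return the multi-index for linear_index in the given basis
    (row-major: last basis entry varies fastest), mirroring the
    proposed some_dialect.delinearize_index op."""
    multi_index = []
    for size in reversed(basis):
        multi_index.append(linear_index % size)
        linear_index //= size
    return tuple(reversed(multi_index))

# E.g. row 73 of the collapsed 100x100 tensor maps back to indices
# (7, 3) in the 10x10 collapsed dimensions of the source.
assert delinearize_index(73, (10, 10)) == (7, 3)
```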

#### More alternatives

Concerning the actual implementation of Potential Solution 2, the second diff below adds this functionality where the iteration is accomplished by `scf.for` or `scf.foreach_thread`. Other abstractions could potentially replace this implementation to accomplish the same effect. This is an initial implementation that could be made irrelevant if additional abstractions are introduced in `tensor`. I believe @nicolasvasilache has some plans here.

#### Diffs

Note that there was a related discussion from an RFC a little over a year ago that proposed similar linearize/delinearize index ops in the context of the `memref` dialect. There could be some set of canonicalizations that operate on such ops in order to perform targeted simplifications, but that's outside this use case.