Summary
This RFC concerns adding a utility to the tensor dialect that replaces the result of a tensor.collapse_shape -> tensor.extract_slice chain with the equivalent result formed by aggregating slices of the tensor.collapse_shape source. To accomplish this, one new auxiliary operation is proposed, some_dialect.delinearize_index, which accepts a linear index and returns a multi-index for some specified basis.
The text below describes the problem and illustrates the proposed changes that provide a solution.
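For intuition, delinearization is just repeated div/mod over the basis. The following Python sketch illustrates the intended semantics; the op's final name, argument order, and dialect are not settled by this RFC:

```python
def delinearize_index(linear_index, basis):
    """Convert a linear index into a multi-index for the given basis.

    `basis` lists the sizes of the collapsed dimensions, outermost first,
    mirroring row-major linearization.
    """
    multi_index = []
    for size in reversed(basis):
        multi_index.append(linear_index % size)
        linear_index //= size
    return tuple(reversed(multi_index))

# Example: row 37 of a 100-row tensor collapsed from a 10x10 grid of rows
# lives at multi-index (3, 7) in the source.
assert delinearize_index(37, (10, 10)) == (3, 7)
```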
Problem
The goal is to solve the following problem:
- You have some program consisting of operations on tensors that looks like the following IR listing. I am using linalg.generic here to represent some operation that implements TilingInterface, and some specifics are omitted to highlight the key issue.
// S0:
%0 = linalg.generic ins(...) outs(...) -> tensor<100x100xf32>
// S1:
%1 = linalg.generic ins (....) outs(...).... -> tensor<10x10x100xf32>
// S2:
%2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<10x10x100xf32> into tensor<100x100xf32>
// S3:
%3 = linalg.matmul ins( %2, %0 : tensor<100x100xf32>, tensor<100x100xf32>)
outs(%init: tensor<100x100xf32>)
- Apply tiling to S3:
// S0:
%0 = linalg.generic ins(...) outs(...) -> tensor<100x100xf32>
// S1:
%1 = linalg.generic ins (....) outs(...).... -> tensor<10x10x100xf32>
// S2:
%2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<10x10x100xf32> into tensor<100x100xf32>
// S3:
%3 = scf.for %iv = %c0 to %c100 step %c10 iter_args(%iter_arg = %init) -> tensor<100x100xf32> {
// S4:
%0_slice = tensor.extract_slice %0 [%iv, 0] [10, 100] [1, 1] : tensor<100x100xf32> to tensor<10x100xf32>
// S5:
%2_slice = tensor.extract_slice %2 [%iv, 0] [10, 100] [1, 1] : tensor<100x100xf32> to tensor<10x100xf32>
// S6:
%4 = linalg.generic ins( %2_slice, %0_slice : tensor<10x100xf32>, tensor<10x100xf32>)
outs(%init: tensor<10x100xf32>)
....
%5 = tensor.insert_slice %4 into %iter_arg[%iv, 0][10, 100][1, 1] : tensor<10x100xf32> into tensor<100x100xf32>
scf.yield %5 : tensor<100x100xf32>
}
At this point you can invoke linalg.generic's implementation of TilingInterface to replace S4 with a tiled version of %0. The same is true of any other operation that implements TilingInterface. For S5, however, the tensor.collapse_shape at S2 will block this from occurring, even though %2 is also the result of a TilingInterface operation.
In general, it is not possible to exchange extract_slice and collapse_shape if linearized dimensions
are sliced.
Just to be very clear, the i-th dimension of the tensor.collapse_shape result is a "linearized sliced dimension" if:
- The i-th reassociation group of tensor.collapse_shape has size greater than 1 (multiple dimensions of the input are collapsed), and
- the i-th dimension is sliced by tensor.extract_slice.
If there are no "linearized sliced dimensions" (e.g. if the extract_slice leaves the collapsed dims alone), then you can do a simple rearrangement to get the tile of %2, but I believe these optimizations are not yet implemented either.
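The distinction can be checked concretely with NumPy, where reshape plays the role of collapse_shape (an illustration of the index arithmetic only, not of MLIR semantics), using the 10x10x100 shape from the running example:

```python
import numpy as np

src = np.arange(10 * 10 * 100).reshape(10, 10, 100)  # like tensor<10x10x100xf32>
collapsed = src.reshape(100, 100)                    # collapse [[0, 1], [2]]

# Slicing only the non-collapsed dimension commutes with the reshape:
# slice-then-collapse equals collapse-then-slice.
a = src[:, :, 5:25].reshape(100, 20)
b = collapsed[:, 5:25]
assert (a == b).all()

# Slicing a linearized dimension does not correspond to a single
# rectangular slice of the source: rows 37..46 of `collapsed` straddle
# source multi-indices (3, 7)..(4, 6), so two slices of `src` are needed.
c = collapsed[37:47, :]
pieces = [src[3, 7:, :], src[4, :7, :]]
assert (c == np.concatenate(pieces)).all()
```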
Potential Solutions
The solution space is basically:
- Potential Solution 1: Move the collapse_shape. In this strategy we push/pull reshapes to the periphery of the block in order to minimize their ability to block tiling. In this case, we could "pull" the reshape to the top of the block if we expand the linalg.generic in S3 to operate on the higher-rank dims.
- Potential Solution 2: Insert IR that computes the result of S5 in terms of %2 and the tiling parameters.
The first solution is currently implemented in linalg (see for example this RFC from @MaheshRavishankar in February). This is probably preferable if you're within the bounds of that strategy, but there are some situations where you might fall outside them:
- It's simply not possible to move the reshape operation. That could be due to hard linalg.generic representation limitations (for example, you can't "fuse in" a collapse_shape on the output into the indexing_map of an input). It could also be due to other limitations if the operation is something other than linalg.generic.
- You may not want to make the originally tiled op in S3 operate on higher-rank dims. For example, if S3 is a matrix multiplication and the linearized dimensions correspond to the parallelizable dimensions of one of the input operands, then you may prefer collapsing multiple parallel dimensions into a single dimension. The reason may be to achieve a load-balancing effect when the tiles are distributed to threads. Similar arguments can be made if the linearized dimensions form a contraction dimension.
Any combination of these situations can get you to the point where you now need Potential Solution 2.
This RFC then concerns adding a rewrite to perform Potential Solution 2 in a naive but generic manner. It introduces minimal new abstractions/ops to accomplish this.
We can accomplish this by stitching together the results of multiple tensor.extract_slice ops formed by iterating over the linearized-and-sliced dimensions. This is equivalent to tiling the linearized dimensions by 1, or to viewing the tensor.collapse_shape as a gather operation on those indices. You could also view this as a TilingInterface implementation of tensor.collapse_shape.
This may or may not actually be beneficial based on any number of factors, so the proposed changes really just introduce this as a targeted utility function plus a test pass.
Here's an example of what this looks like (to replace S5 above):
%2_slice_init = linalg.init_tensor [10, 100] : tensor<10x100xf32>
%collapse_shape_tile = scf.for %linear_idx = %c0 to %c10 step %c1 iter_args(%iter_arg = %2_slice_init) -> tensor<10x100xf32> {
  // Compute the absolute row index into %2, then delinearize it based on
  // the sizes of the collapsed dimensions of %1.
  %row = arith.addi %iv, %linear_idx : index
  %multi_index:2 = some_dialect.delinearize_index %row %c10, %c10
  %input = tensor.extract_slice %1[%multi_index#0, %multi_index#1, 0] [1, 1, 100] [1, 1, 1]
      : tensor<10x10x100xf32> to tensor<1x100xf32>
  // Tiled form of S1 (tensor.extract_slice's for the other operands omitted)
  %tile_slice = linalg.generic ins(...) outs(%input) -> tensor<1x100xf32>
  %tile_update = tensor.insert_slice %tile_slice into %iter_arg[%linear_idx, 0] [1, 100] [1, 1]
      : tensor<1x100xf32> into tensor<10x100xf32>
  scf.yield %tile_update : tensor<10x100xf32>
}
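The stitching loop above can be prototyped in NumPy to check the arithmetic. This is a hedged sketch: `delinearize` is a plain Python stand-in for the proposed op, `sliced_collapse` is a hypothetical helper name, and the shapes match the running example:

```python
import numpy as np

def delinearize(linear_index, basis):
    """Stand-in for the proposed some_dialect.delinearize_index op."""
    multi = []
    for size in reversed(basis):
        multi.append(linear_index % size)
        linear_index //= size
    return tuple(reversed(multi))

def sliced_collapse(src, basis, offset, size):
    """Compute collapse_shape(src)[offset:offset+size, :] by gathering
    rank-reduced slices of `src`, one per linear index (tile-by-1)."""
    rows = []
    for i in range(size):
        idx = delinearize(offset + i, basis)  # multi-index into src
        rows.append(src[idx])                 # one full row, shape (100,)
    return np.stack(rows)

src = np.arange(10 * 10 * 100).reshape(10, 10, 100)
iv = 30  # one tile offset from the scf.for at S3
expected = src.reshape(100, 100)[iv:iv + 10, :]
assert (sliced_collapse(src, (10, 10), iv, 10) == expected).all()
```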
Here the only new op introduced is some_dialect.delinearize_index, which gives the multi-index corresponding to a linearized index. This is added in the spirit of "gradual lowering" and because it could be the subject of additional optimizations. It's not clear where that op should go (it originally went into arith, but then I was asked to revert and write this RFC).
More alternatives
Concerning the actual implementation of Potential Solution 2, the second diff below adds this functionality, where the iteration is accomplished by scf.for or scf.foreach_thread. There are other abstractions that could potentially replace this implementation to accomplish the same effect. This is an initial implementation that could be made irrelevant if additional abstractions are introduced in tensor. I believe @nicolasvasilache has some plans here.
Diffs
Note that there was a related discussion from an RFC a little over a year ago that proposed similar linearize/delinearize index ops in the context of the memref dialect. There could be some set of canonicalizations that operate on such ops in order to perform targeted simplifications, but that's outside this use case.