I would like to propose adding a concatenate op to the Tensor dialect. I’ve found that when going from higher-level dialects to linalg on tensors, we end up with awkward lowerings for concatenations due to the lack of a proper equivalent. Lowerings in IREE from StableHLO, Torch-MLIR, and TosaToTensor are three flavors of this: the former uses a linalg.generic and tensor.extract to insert the correct values into the concatenated tensor, and the latter two use a chain of tensor.insert_slice ops.
Motivation
I’ll start by focusing on the second kind of lowering, which constructs the concatenated tensor with a sequence of inserted slices; the example below is taken directly from an upstream test.
```mlir
func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> () {
  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i32}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
  return
}
```
currently becomes (with constants dropped)
```mlir
module {
  func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) {
    %dim = tensor.dim %arg0, %c0 : tensor<?x1xf32>
    %0 = tensor.empty() : tensor<5x3xf32>
    %dim_5 = tensor.dim %arg0, %c0_4 : tensor<?x1xf32>
    %inserted_slice = tensor.insert_slice %arg0 into %0[0, 0] [%dim_5, 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
    %dim_7 = tensor.dim %arg1, %c0_6 : tensor<?x1xf32>
    %inserted_slice_8 = tensor.insert_slice %arg1 into %inserted_slice[0, 1] [%dim_7, 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
    %dim_10 = tensor.dim %arg2, %c0_9 : tensor<?x1xf32>
    %inserted_slice_11 = tensor.insert_slice %arg2 into %inserted_slice_8[0, 2] [%dim_10, 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
    return
  }
}
```
Now let’s say that I want to transpose the inputs/outputs to this concatenation so that it happens along the outermost dimension (axis = 0). This requires analyzing the chain of tensor.insert_slice ops to ensure both that they all insert full slices along every dimension except the one being concatenated, and that the inserted slices fill the entire destination. As a result, making local decisions about the layout of a concatenated tensor, or trying to propagate a layout through a concatenation, is much more difficult with this representation. To do the same on tosa.concat, we just need to compare the transposed indices with the concatenation axis and introduce new transposes on the inputs to the concat.
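For illustration, a sketch of what the transposed form of the original tosa.concat could look like (the exact tosa.transpose syntax and the static result shape here are assumptions for this 2-D example):

```mlir
// Sketch: transpose each input, then concatenate along the new outermost dim.
%t0 = "tosa.transpose"(%arg0) {perms = array<i32: 1, 0>} : (tensor<?x1xf32>) -> tensor<1x?xf32>
%t1 = "tosa.transpose"(%arg1) {perms = array<i32: 1, 0>} : (tensor<?x1xf32>) -> tensor<1x?xf32>
%t2 = "tosa.transpose"(%arg2) {perms = array<i32: 1, 0>} : (tensor<?x1xf32>) -> tensor<1x?xf32>
// The original axis = 1 becomes axis = 0 under the (1, 0) permutation.
%cat = "tosa.concat"(%t0, %t1, %t2) <{axis = 0 : i32}> : (tensor<1x?xf32>, tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<3x5xf32>
```

The rewrite is purely local: no chain analysis is needed, only remapping the axis attribute through the permutation.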
Moreover, if we later wanted to recover the lowering to a single linalg.generic shown in the above StableHLO example, this would require the same analysis.
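For reference, the linalg.generic flavor of concat lowering looks roughly like the following sketch (assumed static shapes, two inputs %lhs and %rhs concatenated along dim 1; not the exact IR any of the listed lowerings emit):

```mlir
// Each output element branches on its index and extracts from the right input.
%res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    outs(%init : tensor<4x7xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %c3 = arith.constant 3 : index
    // Indices below the first input's extent read from %lhs...
    %in_first = arith.cmpi ult, %j, %c3 : index
    %v = scf.if %in_first -> (f32) {
      %e = tensor.extract %lhs[%i, %j] : tensor<4x3xf32>
      scf.yield %e : f32
    } else {
      // ...otherwise shift the index and read from %rhs.
      %j2 = arith.subi %j, %c3 : index
      %e = tensor.extract %rhs[%i, %j2] : tensor<4x4xf32>
      scf.yield %e : f32
    }
    linalg.yield %v : f32
} -> tensor<4x7xf32>
```

Recovering this form from an insert_slice chain requires proving the chain covers the destination, which is exactly the analysis described above.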
There was a similar RFC in the past, but about adding it to Linalg instead. Here I’m suggesting it be added to tensor because:
a) Given the close relationship between operand sizes and the iteration space in Linalg operations, Linalg has trouble representing a concatenation.
b) This seems to me like a reasonably universal operation for tensors, as evidenced by the fact that all of the dialects listed above include such an op, as does the sparse_tensor dialect.
Proposed representation
One option for representing it would use the same representation as some of the above dialects:
```mlir
%cat = tensor.concatenate dimension(2) %0, %1, ... : (tensor<?x?x?xf32>, ...) into tensor<?x?x?xf32>
```
Here we have a concatenation along a single static dimension, without a destination, on a variadic number of inputs. If this needs to handle unranked tensors as well, we could allow the concatenation dim to be dynamic. The advantage of this representation is that the shape of the destination is implicit in the shapes of the inputs: there is no need to verify that the size of the destination along the concatenated dim equals the sum of the inputs’ sizes. Additionally, this is closer to the representation chosen for other tensor ops like tensor.pad.
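To make the implicit result shape concrete (hypothetical IR, since the op does not exist yet; %a through %d are assumed values):

```mlir
// Static sizes add up along the concatenated dim: 4 + 5 = 9.
%cat = tensor.concatenate dimension(1) %a, %b : (tensor<3x4xf32>, tensor<3x5xf32>) into tensor<3x9xf32>
// Any dynamic input size makes the result dynamic along that dim,
// while the non-concatenated dims must simply agree.
%dyncat = tensor.concatenate dimension(1) %c, %d : (tensor<3x?xf32>, tensor<3x5xf32>) into tensor<3x?xf32>
```

The verifier only needs to check agreement of the non-concatenated dims; the result shape along the concatenated dim follows mechanically.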
Another way to do it would be with destination-passing style, but that requires additional verification on the op to make sure that the inputs and the destination are consistent. In that case it also becomes possible to represent restricted versions of tensor.insert_slice with a concat.
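A sketch of the destination-passing alternative (again hypothetical syntax) illustrates the extra verification burden:

```mlir
// Destination-passing form: the result shape comes from %dest, so a verifier
// must separately check that the input sizes along dim 0 sum to the dest size
// (2 + 3 = 5 here) and that the other dims match.
%cat = tensor.concatenate dimension(0) %0, %1 into %dest
    : (tensor<2x4xf32>, tensor<3x4xf32>) into tensor<5x4xf32>
```

This consistency check is exactly what the destination-free representation avoids.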