[RFC][Tensor] Add a `tensor.concatenate` operation

I would like to propose adding a concatenate op to the Tensor dialect. I’ve found that when going from higher-level dialects to Linalg on tensors, we end up with awkward lowerings for concatenations due to the lack of a proper equivalent. The lowerings in IREE from StableHLO, Torch-MLIR, and TosaToTensor are three flavors of this: the first uses a linalg.generic with tensor.extract to gather the correct values into the concatenated tensor, while the latter two use a chain of tensor.insert_slice ops.

Motivation

I’ll start by focusing on the second kind of lowering, which constructs the concatenated tensor with a sequence of inserted slices; the example below is taken directly from an upstream test.

func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> () { 
  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i32}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
  return
}

currently becomes (with constants dropped)

module {
  func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) {
    %dim = tensor.dim %arg0, %c0 : tensor<?x1xf32>
    %0 = tensor.empty() : tensor<5x3xf32>
    %dim_5 = tensor.dim %arg0, %c0_4 : tensor<?x1xf32>
    %inserted_slice = tensor.insert_slice %arg0 into %0[0, 0] [%dim_5, 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
    %dim_7 = tensor.dim %arg1, %c0_6 : tensor<?x1xf32>
    %inserted_slice_8 = tensor.insert_slice %arg1 into %inserted_slice[0, 1] [%dim_7, 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
    %dim_10 = tensor.dim %arg2, %c0_9 : tensor<?x1xf32>
    %inserted_slice_11 = tensor.insert_slice %arg2 into %inserted_slice_8[0, 2] [%dim_10, 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
    return
  }
}

Now let’s say that I want to transpose the inputs/outputs of this concatenation so that it happens along the outermost dimension (axis = 0). This requires analyzing the chain of tensor.insert_slice ops to ensure that they all insert full slices along every dimension except the one being concatenated, and that the inserted slices fill the entire destination. As a result, making local decisions about the layout of a concatenated tensor, or trying to propagate a layout through a concatenation, is much more difficult with this representation. To do the same on tosa.concat, we just need to compare the transposed indices with the concatenation axis and introduce new transposes on the inputs to the concat.
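For example, the transposed form of the tosa.concat above could look roughly like the following (a hedged sketch in approximate TOSA assembly; the permutation constant and intermediate names are made up for illustration):

%perms = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
%t0 = "tosa.transpose"(%arg0, %perms) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<1x?xf32>
%t1 = "tosa.transpose"(%arg1, %perms) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<1x?xf32>
%t2 = "tosa.transpose"(%arg2, %perms) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<1x?xf32>
%cat = "tosa.concat"(%t0, %t1, %t2) <{axis = 0 : i32}> : (tensor<1x?xf32>, tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<3x5xf32>

The only change to the concat itself is the axis attribute; no analysis of slice coverage is needed.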

Moreover, if we later wanted to recover the single-linalg.generic lowering taken by the StableHLO path mentioned above (sketched below), this would require the same analysis.
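For reference, here is a hedged sketch of that linalg.generic-style lowering for a two-input concatenation along dim 1 (illustrative only, not the actual StableHLO conversion output; the function and value names are made up):

func.func @concat_as_generic(%a: tensor<?x4xf32>, %b: tensor<?x5xf32>) -> tensor<?x9xf32> {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %d0 = tensor.dim %a, %c0 : tensor<?x4xf32>
  %init = tensor.empty(%d0) : tensor<?x9xf32>
  %cat = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%init : tensor<?x9xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    // Select which input to read from based on the output column.
    %in_a = arith.cmpi ult, %j, %c4 : index
    %v = scf.if %in_a -> (f32) {
      %e = tensor.extract %a[%i, %j] : tensor<?x4xf32>
      scf.yield %e : f32
    } else {
      %jb = arith.subi %j, %c4 : index
      %e = tensor.extract %b[%i, %jb] : tensor<?x5xf32>
      scf.yield %e : f32
    }
    linalg.yield %v : f32
  } -> tensor<?x9xf32>
  return %cat : tensor<?x9xf32>
}

Recovering this form from a chain of tensor.insert_slice ops requires proving the same coverage properties described above.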

There was a similar RFC in the past, but it proposed adding the op to Linalg instead. Here I’m suggesting it be added to tensor because:
a) Given the close relationship between operand sizes and the iteration space in Linalg operations, Linalg has trouble representing a concatenation.
b) This seems to me like a reasonably universal operation for tensors, as evidenced by the fact that all of the dialects listed above include such an op, as does the sparse_tensor dialect.

Proposed representation

One option for representing it would use the same representation as some of the dialects above:

%cat = tensor.concatenate dimension(2) %0, %1, ... : (tensor<?x?x?xf32>, ...) into tensor<?x?x?xf32>

Here we have a concatenation along a single static dimension, without a destination, on a variadic number of inputs. If this needs to handle unranked tensors as well, we could allow the concatenation dim to be dynamic. The advantage of this representation is that the shape of the destination is implicit in the shapes of the inputs: there is no need to verify that the size of the destination along the concatenated dim equals the sum of the inputs’ sizes. Additionally, this is closer to the choice of representation for other tensor ops like tensor.pad.
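As a small hypothetical example of the implicit result shape, two statically shaped inputs concatenated along dim 1 would produce a result whose concatenated size is simply the sum of the input sizes (4 + 5 = 9), with all other dims required to match:

%cat = tensor.concatenate dimension(1) %a, %b : (tensor<2x4xf32>, tensor<2x5xf32>) into tensor<2x9xf32>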

Another way to do it might be with destination passing, but that requires additional verification on the op to make sure that the inputs and the destination are consistent. It would also make it possible to represent restricted versions of tensor.insert_slice with a concat.
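A purely hypothetical sketch of that destination-passing variant (made-up syntax, not what is proposed above); the verifier would have to check that %dest is consistent with the inputs:

// %dest provides the result tensor; its shape must be verified against the inputs.
%c0 = arith.constant 0 : index
%d0 = tensor.dim %a, %c0 : tensor<?x4xf32>
%dest = tensor.empty(%d0) : tensor<?x9xf32>
%cat = tensor.concatenate dimension(1) %a, %b into %dest : (tensor<?x4xf32>, tensor<?x5xf32>) -> tensor<?x9xf32>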

cc @MaheshRavishankar @hanchung


Would this op then be targeted instead by the TOSA & StableHLO lowerings? Is this mostly for analysis & semantic representation, while later lowerings could use the other forms? I was going to inject part of an answer but figured it better to just ask: why not just reuse the TOSA op?

I left this out of my post above (because I didn’t want to sound like this was geared specifically toward IREE), but part of the reason is to avoid having to handle TOSA, StableHLO, or Torch in later flows. IREE uses Linalg on Tensors as a convergence point and might want to make such layout decisions long after converting from TOSA. I’m under the impression the same might be true for other consumers of the op as well.

Another reason was kind of hinted at above, and that would be to allow layout propagation through concatenations, re: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp. This is just one example of an analysis that does layout propagation (based on tensor.pack/unpack), but I don’t think we should be taking a TOSA dependency here.

As another example, tensor.pad is even relevant in later lowerings (see some of its usage in the Transform dialect), and based on the previous Linalg RFC, the same might be true here. My understanding of the TOSA dialect is that we wouldn’t want to start adding it as a dependency in various Linalg transformations. I might have a narrow view of the expected usage, though.

I think having a separate lowering to Linalg (the StableHLO lowering) and a decomposition (the TOSA/Torch lowerings) would make sense, and then we could replace the TOSA lowering with this op and let the consumer decide how they want it to be handled.

Yeah, sounds like a good idea to me!

To answer why not the TOSA dialect: I think the TOSA dialect is very closely tied to the TOSA spec. We probably need an operation in a dialect that is anchored more on program-level transformations than on a spec.

While @qed mentioned what IREE does, even stepping out of IREE, I think not anchoring transformations on something that is meant as an “input” dialect allows better separation of concerns. If we need to evolve the concat semantics over time (not saying I know what it would evolve to), it would be easier to do on an operation in the tensor dialect than in the TOSA dialect.

A tensor.concat op that one can tile and fuse through without having to perform complex analyses sounds great to me too.

It was also mentioned here; I am glad this is being picked up.

Yes, this is what I figured, but I wanted to be clear. This is a useful representation on its own, independent of further lowerings, and it seems to simplify analysis.

The proposed representation seems consistent. I’m not sure whether we’ll have many unranked cases here or whether that would be more of a frontend thing that’s lowered away by this point (I’m biased, but I like the flexibility).

Sounds good, I can start executing on the above representation then.

It sounds good to me. I’ve been wanting this for a long time!


Side note: I wouldn’t rely on this in the op design: it may be quite useful to concat two ? dims into a static value, subject to the traditional runtime requirements.
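For instance (hypothetical shapes, using the proposed syntax), both inputs could be dynamic while the result is static; the verifier would accept this and the size requirement becomes a runtime contract:

// d0 + d1 == 8 is checked at runtime, not by the verifier.
%cat = tensor.concatenate dimension(0) %a, %b : (tensor<?xf32>, tensor<?xf32>) into tensor<8xf32>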


I put together a first pass at the op implementation here: [mlir][tensor] Add a tensor.concat operation by qedawkins · Pull Request #72779 · llvm/llvm-project · GitHub